diff --git a/diskann-benchmark/example/multi-vector-3way.json b/diskann-benchmark/example/multi-vector-3way.json new file mode 100644 index 000000000..3a036d60e --- /dev/null +++ b/diskann-benchmark/example/multi-vector-3way.json @@ -0,0 +1,47 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "isa": "reference", + "runs": [ + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 16, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 64, "loops_per_measurement": 100, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 128, "loops_per_measurement": 50, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 256, "loops_per_measurement": 25, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 512, "loops_per_measurement": 12, "num_measurements": 50 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "isa": "x86-64-v3", + "runs": [ + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 16, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 64, "loops_per_measurement": 100, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 128, "loops_per_measurement": 50, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 256, "loops_per_measurement": 25, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 512, "loops_per_measurement": 12, "num_measurements": 50 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "isa": "x86-64-v3-staged", + "runs": [ + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 16, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 64, "loops_per_measurement": 100, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 128, "loops_per_measurement": 50, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 256, "loops_per_measurement": 25, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 256, "dim": 512, "loops_per_measurement": 12, "num_measurements": 50 } + ] + } + } + ] +} diff --git a/diskann-benchmark/example/multi-vector-quant.json b/diskann-benchmark/example/multi-vector-quant.json new file mode 100644 index 000000000..ffc4131ac --- /dev/null +++ b/diskann-benchmark/example/multi-vector-quant.json @@ -0,0 +1,20 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "multi-vector-quant-op", + "content": { + "runs": [ + { "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + } + ] +} diff --git a/diskann-benchmark/example/multi-vector-staged.json b/diskann-benchmark/example/multi-vector-staged.json new file mode 100644 index 000000000..8cb5798dc --- /dev/null +++ b/diskann-benchmark/example/multi-vector-staged.json @@ -0,0 +1,41 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "isa": "x86-64-v3", + "runs": [ + { "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "isa": "x86-64-v3-staged", + "runs": [ + { "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + } + ] +} diff --git a/diskann-benchmark/src/inputs/multi_vector.rs b/diskann-benchmark/src/inputs/multi_vector.rs index c74f9d232..5b2aaa833 100644 --- a/diskann-benchmark/src/inputs/multi_vector.rs +++ b/diskann-benchmark/src/inputs/multi_vector.rs @@ -26,6 +26,9 @@ pub(crate) enum BenchIsa { #[serde(rename = "x86-64-v3")] #[allow(non_camel_case_types)] X86_64_V3, + #[serde(rename = "x86-64-v3-staged")] + #[allow(non_camel_case_types)] + X86_64_V3_Staged, Neon, Scalar, Reference, @@ -37,6 +40,7 @@ impl std::fmt::Display for BenchIsa { let st = match self { Self::X86_64_V4 => "x86-64-v4", Self::X86_64_V3 => "x86-64-v3", + Self::X86_64_V3_Staged => "x86-64-v3-staged", Self::Neon => "neon", Self::Scalar => "scalar", Self::Reference => "reference", @@ -51,6 +55,7 @@ impl From for MaxSimIsa { match b { BenchIsa::X86_64_V4 => MaxSimIsa::X86_64_V4, BenchIsa::X86_64_V3 => MaxSimIsa::X86_64_V3, + BenchIsa::X86_64_V3_Staged => MaxSimIsa::X86_64_V3_Staged, BenchIsa::Neon => MaxSimIsa::Neon, BenchIsa::Scalar => MaxSimIsa::Scalar, BenchIsa::Reference => MaxSimIsa::Reference, @@ -149,3 +154,80 @@ impl std::fmt::Display for MultiVectorOp { Ok(()) } } + +/////////////////////////////// +// Multi-Vector Quantized Op // +/////////////////////////////// + +/// A 4-bit MinMax **quantized** multi-vector MaxSim A/B benchmark job: the +/// experimental staged integer kernel vs the scalar `MinMaxKernel` reference, +/// at identical shapes and quantization. +/// +/// The element type is implicitly f32 input → 4-bit MinMax codes, and the ISA is +/// fixed to V3/AVX2 (the only quantized staged kernel), so neither is a JSON +/// field. x86_64-only, like the kernel it drives. +#[cfg(all(feature = "multi-vector", target_arch = "x86_64"))] +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct MultiVectorQuantOp { + pub(crate) runs: Vec, +} + +#[cfg(all(feature = "multi-vector", target_arch = "x86_64"))] +impl MultiVectorQuantOp { + pub(crate) const fn tag() -> &'static str { + "multi-vector-quant-op" + } +} + +#[cfg(all(feature = "multi-vector", target_arch = "x86_64"))] +impl Input for MultiVectorQuantOp { + type Raw = Self; + + fn tag() -> &'static str { + Self::tag() + } + + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self { + const NUM_DOC_VECTORS: NonZeroUsize = NonZeroUsize::new(64).unwrap(); + const DIM: NonZeroUsize = NonZeroUsize::new(128).unwrap(); + const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(50).unwrap(); + const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(20).unwrap(); + + let runs = vec![ + Run { + num_query_vectors: NonZeroUsize::new(32).unwrap(), + num_doc_vectors: NUM_DOC_VECTORS, + dim: DIM, + loops_per_measurement: LOOPS_PER_MEASUREMENT, + num_measurements: NUM_MEASUREMENTS, + }, + Run { + num_query_vectors: NonZeroUsize::new(64).unwrap(), + num_doc_vectors: NUM_DOC_VECTORS, + dim: DIM, + loops_per_measurement: LOOPS_PER_MEASUREMENT, + num_measurements: NUM_MEASUREMENTS, + }, + ]; + + Self { runs } + } +} + +#[cfg(all(feature = "multi-vector", target_arch = "x86_64"))] +impl std::fmt::Display for MultiVectorQuantOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Multi-Vector Quantized Operation (4-bit MinMax)\n")?; + write_field!(f, "tag", Self::tag())?; + write_field!(f, "number of runs", self.runs.len())?; + Ok(()) + } +} diff --git a/diskann-benchmark/src/multi_vector/mod.rs b/diskann-benchmark/src/multi_vector/mod.rs index dfad330af..a01285ba1 100644 --- a/diskann-benchmark/src/multi_vector/mod.rs +++ b/diskann-benchmark/src/multi_vector/mod.rs @@ -25,9 +25,15 @@ cfg_if::cfg_if! { if #[cfg(feature = "multi-vector")] { mod driver; mod kernels; + // The quantized A/B op drives the V3-only staged integer kernel. + #[cfg(target_arch = "x86_64")] + mod quant; pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { - kernels::register(registry) + kernels::register(registry)?; + #[cfg(target_arch = "x86_64")] + quant::register(registry)?; + Ok(()) } } else { crate::utils::stub_impl!("multi-vector", inputs::multi_vector::MultiVectorOp); diff --git a/diskann-benchmark/src/multi_vector/quant.rs b/diskann-benchmark/src/multi_vector/quant.rs new file mode 100644 index 000000000..1d09ca88b --- /dev/null +++ b/diskann-benchmark/src/multi_vector/quant.rs @@ -0,0 +1,268 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! A/B benchmark for **4-bit MinMax quantized** multi-vector MaxSim: the +//! experimental *staged integer* kernel (block-transposed `i16` query + `u8` +//! doc codes, `vpmaddwd` accumulation, metadata postprocess) vs the scalar +//! [`MinMaxKernel`] reference — at identical shapes and identical quantization. +//! +//! Both paths consume the *same* random f32 multi-vectors quantized to 4-bit +//! MinMax (Null transform, scale 1.0), so the comparison isolates the distance +//! kernel. The build / quantize cost is excluded from the timing. +//! +//! x86_64 (V3/AVX2) only — the quantized staged kernel has no other backend. + +use std::io::Write; +use std::num::NonZeroUsize; + +use diskann_benchmark_runner::{ + benchmark::{FailureScore, MatchScore}, + utils::{fmt::Table, percentiles, MicroSeconds}, + Benchmark, Checkpoint, Output, Registry, +}; +use diskann_quantization::algorithms::transforms::NullTransform; +use diskann_quantization::algorithms::Transform; +use diskann_quantization::minmax::{MinMaxMeta, MinMaxQuantizer}; +use diskann_quantization::multi_vector::distance::{QuantStagedDocs, QuantStagedQuery}; +use diskann_quantization::multi_vector::{Defaulted, Mat, MatRef, MaxSim, QueryMatRef, Standard}; +use diskann_quantization::num::Positive; +use diskann_quantization::CompressInto; +use diskann_utils::ReborrowMut; +use diskann_vector::DistanceFunctionMut; +use serde::{Deserialize, Serialize}; + +use super::driver::Data; +use crate::inputs::multi_vector::{MultiVectorQuantOp, Run}; +use crate::utils::DisplayWrapper; + +// ───────────────────────────────────────────────────────────────────────── +// Kernel. +// ───────────────────────────────────────────────────────────────────────── + +#[derive(Debug)] +pub(super) struct QuantKernel; + +impl QuantKernel { + pub(super) const fn new() -> Self { + Self + } +} + +impl Benchmark for QuantKernel { + type Input = MultiVectorQuantOp; + type Output = Vec; + + fn try_match(&self, _from: &MultiVectorQuantOp) -> Result { + // The staged integer kernel requires AVX2 (V3). + if QuantStagedQuery::is_supported() { + Ok(MatchScore(0)) + } else { + Err(FailureScore(0)) + } + } + + fn run( + &self, + input: &MultiVectorQuantOp, + _: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + writeln!(output, "{}", input)?; + let mut results = Vec::with_capacity(input.runs.len()); + for run in input.runs.iter() { + results.push(run_ab(run)?); + } + writeln!(output, "\n\n{}", DisplayWrapper(&*results))?; + Ok(results) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&MultiVectorQuantOp>, + ) -> std::fmt::Result { + match input { + None => writeln!(f, "- 4-bit MinMax quantized staged MaxSim (V3/AVX2)")?, + Some(_) => { + if !QuantStagedQuery::is_supported() { + writeln!(f, "\n - AVX2 (V3) unavailable on this CPU")?; + } + } + } + Ok(()) + } +} + +// ───────────────────────────────────────────────────────────────────────── +// A/B timing. +// ───────────────────────────────────────────────────────────────────────── + +/// Quantize an f32 multi-vector to 4-bit MinMax (Null transform, scale 1.0) — +/// the same quantizer both paths share so the codes + metadata are identical. +fn quantize(input: MatRef<'_, Standard>) -> Mat> { + let dim = input.vector_dim(); + let n = input.num_vectors(); + let q = MinMaxQuantizer::new( + Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())), + Positive::new(1.0).unwrap(), + ); + let mut out: Mat> = Mat::new(MinMaxMeta::new(n, dim), Defaulted).unwrap(); + q.compress_into(input, out.reborrow_mut()).unwrap(); + out +} + +/// Run `f` `loops_per_measurement` times per measurement, `num_measurements` +/// times, returning the per-measurement latencies and their percentiles. +fn measure(run: &Run, mut f: impl FnMut()) -> Series { + let mut latencies = Vec::with_capacity(run.num_measurements.get()); + for _ in 0..run.num_measurements.get() { + let start = std::time::Instant::now(); + for _ in 0..run.loops_per_measurement.get() { + f(); + } + latencies.push(start.elapsed().into()); + } + let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap(); + Series { + latencies, + percentiles, + } +} + +/// Build both kernels for one shape and time them (build / quantize excluded). +fn run_ab(run: &Run) -> anyhow::Result { + let data = Data::::new(run)?; + + // Path A — staged integer kernel (quantizes internally at build time). + let mut query = QuantStagedQuery::build(data.queries.as_view()) + .ok_or_else(|| anyhow::anyhow!("AVX2 (V3) unavailable for the staged quantized kernel"))?; + let docs = QuantStagedDocs::build(data.docs.as_view()); + + // Path B — scalar MinMax reference over the same quantization. + let q_ref = quantize(data.queries.as_view()); + let d_ref = quantize(data.docs.as_view()); + + let nq = run.num_query_vectors.get(); + let mut scores = vec![0.0f32; nq]; + + // Launder BOTH the inputs and the output through `black_box` each iteration. + // Output-only `black_box` is not enough: the reference chain is `#[inline(always)]` + // end-to-end with loop-invariant inputs, so the optimizer could hoist/elide it out + // of the measured loop (the staged path is an opaque cross-crate call and cannot be), + // making the A/B asymmetric. Laundering the inputs forces both paths to re-run the + // full per-call work every iteration. + let staged = measure(run, || { + let docs = std::hint::black_box(&docs); + query.compute_max_sim(docs, &mut scores); + std::hint::black_box(&mut scores); + }); + + let reference = measure(run, || { + let q_ref = std::hint::black_box(&q_ref); + let d_ref = std::hint::black_box(&d_ref); + let query_ref: QueryMatRef<_> = q_ref.as_view().into(); + MaxSim::new(&mut scores).evaluate(query_ref, d_ref.as_view()); + std::hint::black_box(&mut scores); + }); + + Ok(QuantRunResult { + run: run.clone(), + staged, + reference, + }) +} + +// ───────────────────────────────────────────────────────────────────────── +// Result types. +// ───────────────────────────────────────────────────────────────────────── + +/// One timed series (per-measurement latencies + percentiles). +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct Series { + latencies: Vec, + percentiles: percentiles::Percentiles, +} + +impl Series { + /// Minimum latency, in microseconds. + fn min_us(&self) -> f64 { + self.latencies + .iter() + .min() + .copied() + .unwrap_or(MicroSeconds::new(u64::MAX)) + .as_f64() + } +} + +/// Staged-vs-reference result for one shape. +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct QuantRunResult { + pub(super) run: Run, + pub(super) staged: Series, + pub(super) reference: Series, +} + +impl QuantRunResult { + fn computations(&self) -> f64 { + (self.run.num_query_vectors.get() + * self.run.num_doc_vectors.get() + * self.run.loops_per_measurement.get()) as f64 + } +} + +impl std::fmt::Display for DisplayWrapper<'_, [QuantRunResult]> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_empty() { + return Ok(()); + } + + writeln!( + f, + "ns/IP = min time per (query, doc) inner-product call; \ + Speedup = reference / staged (>1 ⇒ staged faster)" + )?; + + let header = [ + "Q", + "D", + "Dim", + "Staged (ns/IP)", + "Reference (ns/IP)", + "Speedup", + ]; + let mut table = Table::new(header, self.len()); + + self.iter().enumerate().for_each(|(row, r)| { + let comps = r.computations(); + let staged = r.staged.min_us() / comps * 1000.0; + let reference = r.reference.min_us() / comps * 1000.0; + let speedup = if staged > 0.0 { + reference / staged + } else { + 0.0 + }; + + let mut row = table.row(row); + row.insert(r.run.num_query_vectors, 0); + row.insert(r.run.num_doc_vectors, 1); + row.insert(r.run.dim, 2); + row.insert(format!("{:.3}", staged), 3); + row.insert(format!("{:.3}", reference), 4); + row.insert(format!("{:.2}x", speedup), 5); + }); + + table.fmt(f) + } +} + +// ───────────────────────────────────────────────────────────────────────── +// Registration. +// ───────────────────────────────────────────────────────────────────────── + +pub(super) fn register(registry: &mut Registry) -> anyhow::Result<()> { + registry.register("multi-vector-quant-op", QuantKernel::new())?; + Ok(()) +} diff --git a/diskann-quantization/src/multi_vector/distance/factory.rs b/diskann-quantization/src/multi_vector/distance/factory.rs index 5dcd4b8cd..65a69691a 100644 --- a/diskann-quantization/src/multi_vector/distance/factory.rs +++ b/diskann-quantization/src/multi_vector/distance/factory.rs @@ -18,6 +18,8 @@ use super::isa::{MaxSimIsa, NotSupported}; use super::kernel::{Erase, MaxSimKernel}; use super::kernels::f16::F16Entry; use super::kernels::f32::F32Kernel; +#[cfg(target_arch = "x86_64")] +use super::kernels::staged::{F32StagedScratch, StagedF32Kernel, StagedRun}; use super::max_sim::{MaxSim, MaxSimError}; use crate::multi_vector::distance::QueryMatRef; use crate::multi_vector::{BlockTransposed, BlockTransposedRef, Mat, MatRef, Standard}; @@ -266,6 +268,92 @@ impl> } } +// ───────────────────────────────────────────────────────────────────────── +// Staged kernel (experimental) — selected by MaxSimIsa::X86_64_V3_Staged. +// Coexists with the fused `Prepared` path above for A/B benchmarking. +// ───────────────────────────────────────────────────────────────────────── + +/// Counterpart to [`Prepared`] for the staged f32 kernel. Owns a +/// [`F32StagedScratch`] (reset-arena + reused `state`) behind a `RefCell`: because +/// `compute_max_sim` takes `&self` and [`MaxSimKernel`] is `Send` but not `Sync`, +/// interior mutability lets each call reset+reuse the scratch with **zero heap +/// allocation** in steady state. +#[cfg(target_arch = "x86_64")] +#[derive(Debug)] +struct PreparedStaged { + arch: A, + prepared: BlockTransposed, + scratch: std::cell::RefCell, +} + +#[cfg(target_arch = "x86_64")] +impl MaxSimKernel for PreparedStaged +where + A: Architecture, + StagedF32Kernel: for<'a> diskann_wide::arch::Target3< + A, + (), + BlockTransposedRef<'a, f32, GROUP>, + MatRef<'a, Standard>, + StagedRun<'a>, + >, +{ + fn nrows(&self) -> usize { + self.prepared.nrows() + } + + fn compute_max_sim( + &self, + doc: MatRef<'_, Standard>, + scores: &mut [f32], + ) -> Result<(), MaxSimError> { + if scores.len() != self.nrows() { + return Err(MaxSimError::InvalidBufferLength(scores.len(), self.nrows())); + } + if doc.num_vectors() == 0 { + scores.fill(f32::MAX); + return Ok(()); + } + let padded = self.prepared.padded_nrows(); + let nrows = self.prepared.nrows(); + // Reset + reuse the owned arena/state scratch (no per-call allocation). + self.scratch.borrow_mut().run(padded, |state, alloc| { + self.arch.run3( + StagedF32Kernel::, + self.prepared.reborrow(), + doc, + StagedRun { + state: &mut state[..], + alloc, + }, + ); + // Distance = negated max inner product (matches the fused path). + for (dst, &src) in scores.iter_mut().zip(state[..nrows].iter()) { + *dst = -src; + } + }); + Ok(()) + } +} + +#[cfg(target_arch = "x86_64")] +struct BuildAndEraseStaged(E); + +#[cfg(target_arch = "x86_64")] +impl> diskann_wide::arch::Target1>> + for BuildAndEraseStaged +{ + fn run(self, arch: V3, query: MatRef<'_, Standard>) -> E::Output { + let prepared = BlockTransposed::::from_matrix_view(query.as_matrix_view()); + let scratch = std::cell::RefCell::new(F32StagedScratch::new(prepared.padded_nrows())); + self.0.erase(PreparedStaged { + arch, + prepared, + scratch, + }) + } +} + // ───────────────────────────────────────────────────────────────────────── // MaxSimElement — sealed trait gating accepted element types. // ───────────────────────────────────────────────────────────────────────── @@ -325,11 +413,21 @@ impl MaxSimElement for f32 { })?; Ok(arch.run1(BuildAndErase(erase), query)) } + #[cfg(target_arch = "x86_64")] + MaxSimIsa::X86_64_V3_Staged => { + let arch = V3::new_checked().ok_or(NotSupported { + isa, + reason: "AVX2/FMA unavailable on this CPU", + })?; + Ok(arch.run1(BuildAndEraseStaged(erase), query)) + } #[cfg(not(target_arch = "x86_64"))] - MaxSimIsa::X86_64_V3 | MaxSimIsa::X86_64_V4 => Err(NotSupported { - isa, - reason: "x86_64 target only", - }), + MaxSimIsa::X86_64_V3 | MaxSimIsa::X86_64_V4 | MaxSimIsa::X86_64_V3_Staged => { + Err(NotSupported { + isa, + reason: "x86_64 target only", + }) + } #[cfg(target_arch = "aarch64")] MaxSimIsa::Neon => { let arch = Neon::new_checked().ok_or(NotSupported { @@ -394,6 +492,10 @@ impl MaxSimElement for half::f16 { isa, reason: "aarch64 target only", }), + MaxSimIsa::X86_64_V3_Staged => Err(NotSupported { + isa, + reason: "x86-64-v3-staged supports f32 only", + }), MaxSimIsa::Reference => Ok(erase.erase(ReferenceKernel::::new(query))), } } @@ -538,6 +640,48 @@ mod tests { assert_eq!(kernel.nrows(), 5); } + /// Reusing one staged f32 `MaxSimKernel` across multiple `compute_max_sim` + /// calls (different doc counts) must match the reference each time — the + /// regression guard for the `PreparedStaged`-owned reset-arena + /// (`F32StagedScratch`): each call rewinds + reuses the scratch, so a stale or + /// aliased buffer would corrupt a later call. + #[cfg(target_arch = "x86_64")] + #[test] + fn staged_f32_arena_reuse() { + if diskann_wide::arch::x86_64::V3::new_checked().is_none() { + return; // No AVX2 on this host; the staged f32 path cannot build. + } + + const NQ: usize = 17; // exercises the A-panel row padding (17 -> 32) + const DIM: usize = 96; + let q_data = make_test_data::(NQ * DIM, DIM, DIM / 2); + let query = make_mat(&q_data, NQ, DIM); + + let staged = build_max_sim::(MaxSimIsa::X86_64_V3_Staged, query, BoxErase).unwrap(); + let reference = build_max_sim::(MaxSimIsa::Reference, query, BoxErase).unwrap(); + + // Distinct doc counts (multi-tile, single panel, remainder, one) reusing the + // same kernel — the arena is reset, never reallocated, between calls. + for (call, &nd) in [200usize, 3, 33, 1].iter().enumerate() { + let d_data = make_test_data::(nd * DIM, DIM, DIM + call); + let doc = make_mat(&d_data, nd, DIM); + + let mut got = vec![0.0f32; NQ]; + staged.compute_max_sim(doc, &mut got).unwrap(); + let mut want = vec![0.0f32; NQ]; + reference.compute_max_sim(doc, &mut want).unwrap(); + + for i in 0..NQ { + assert!( + (got[i] - want[i]).abs() <= 1e-4 * want[i].abs().max(1.0), + "call {call} (nd={nd}) row {i}: reused staged f32 {} != reference {}", + got[i], + want[i], + ); + } + } + } + fn check_size_mismatch(label: &str) where T: MaxSimElement + FromF32, diff --git a/diskann-quantization/src/multi_vector/distance/isa.rs b/diskann-quantization/src/multi_vector/distance/isa.rs index d295438bc..29ec6dffd 100644 --- a/diskann-quantization/src/multi_vector/distance/isa.rs +++ b/diskann-quantization/src/multi_vector/distance/isa.rs @@ -24,6 +24,10 @@ pub enum MaxSimIsa { X86_64_V3, /// x86_64 AVX-512. X86_64_V4, + /// Experimental staged-pipeline kernel (x86_64 AVX2+FMA). Coexists with + /// [`Self::X86_64_V3`] for A/B benchmarking; produces the same results via a + /// different (kernel → postprocess → reducer) micro-kernel structure. + X86_64_V3_Staged, /// AArch64 Neon. Neon, /// Non-SIMD reference fallback. Slow; serves as a correctness baseline. @@ -41,8 +45,10 @@ impl MaxSimIsa { Self::X86_64_V3 => diskann_wide::arch::x86_64::V3::new_checked().is_some(), #[cfg(target_arch = "x86_64")] Self::X86_64_V4 => diskann_wide::arch::x86_64::V4::new_checked().is_some(), + #[cfg(target_arch = "x86_64")] + Self::X86_64_V3_Staged => diskann_wide::arch::x86_64::V3::new_checked().is_some(), #[cfg(not(target_arch = "x86_64"))] - Self::X86_64_V3 | Self::X86_64_V4 => false, + Self::X86_64_V3 | Self::X86_64_V4 | Self::X86_64_V3_Staged => false, #[cfg(target_arch = "aarch64")] Self::Neon => diskann_wide::arch::aarch64::Neon::new_checked().is_some(), #[cfg(not(target_arch = "aarch64"))] @@ -58,6 +64,7 @@ impl std::fmt::Display for MaxSimIsa { Self::Scalar => "scalar", Self::X86_64_V3 => "x86-64-v3", Self::X86_64_V4 => "x86-64-v4", + Self::X86_64_V3_Staged => "x86-64-v3-staged", Self::Neon => "neon", Self::Reference => "reference", }; diff --git a/diskann-quantization/src/multi_vector/distance/kernel.rs b/diskann-quantization/src/multi_vector/distance/kernel.rs index b292def54..127c7f99e 100644 --- a/diskann-quantization/src/multi_vector/distance/kernel.rs +++ b/diskann-quantization/src/multi_vector/distance/kernel.rs @@ -6,7 +6,13 @@ use crate::multi_vector::{MatRef, MaxSimError, Standard}; /// Object-safe interface for computing per-query MaxSim scores. -pub trait MaxSimKernel: Send + Sync + std::fmt::Debug { +/// +/// `Send` (not `Sync`): a built kernel can be **moved** to a worker thread that +/// owns it (the "each search thread owns its distance computer" model), but is +/// not required to be shared by reference across threads. Dropping `Sync` is what +/// lets a kernel own interior-mutable per-call scratch (e.g. the staged f32 +/// kernel's reset-arena) under a `&self` method. +pub trait MaxSimKernel: Send + std::fmt::Debug { /// Number of query rows whose scores this kernel produces. fn nrows(&self) -> usize; diff --git a/diskann-quantization/src/multi_vector/distance/kernels/mod.rs b/diskann-quantization/src/multi_vector/distance/kernels/mod.rs index 55108698d..61182315a 100644 --- a/diskann-quantization/src/multi_vector/distance/kernels/mod.rs +++ b/diskann-quantization/src/multi_vector/distance/kernels/mod.rs @@ -16,8 +16,16 @@ pub(super) mod f16; pub(super) mod f32; mod layouts; mod reduce; +// The staged kernel is V3 (x86_64) only; gate the whole module so its support +// code isn't dead on other architectures. +#[cfg(target_arch = "x86_64")] +pub(super) mod staged; mod tiled_reduce; +// Re-export the quantized staged kernel's public POC entry (x86_64 only). +#[cfg(target_arch = "x86_64")] +pub use staged::{QuantStagedDocs, QuantStagedQuery}; + // ── Tile budget ────────────────────────────────────────────────── /// Cache budgets fed to the tile planner. diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/README.md b/diskann-quantization/src/multi_vector/distance/kernels/staged/README.md new file mode 100644 index 000000000..61c1389a8 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/README.md @@ -0,0 +1,91 @@ +# Staged MaxSim kernel (`tiled_reduce_staged`) — design + +> **Status: experimental POC.** This module validates a design hypothesis; it is +> not yet the production MaxSim path. See *Scope* below. + +## Thesis + +One generic, cache-tiled reduction driver — [`tiled_reduce_staged`](driver.rs) — +can compute multi-vector MaxSim/Chamfer for **any element type and any +quantization** by swapping small pluggable *stages*, instead of forking the tiled +loop nest once per datatype. This module proves that claim with two +instantiations: + +- **f32** — bit-identical to, and on par with, the hand-fused V3 kernel. +- **4-bit MinMax-quantized `i8`** — a *new datatype* added by swapping two stages + only, with a real speedup over the per-pair SIMD reference. + +## Shape + +The driver owns tiling/blocking and walks query×doc tiles. Per tile it calls four +pluggable stages, defined in [`mod.rs`](mod.rs): + +| Stage | Trait | Role | +|-------|-------|------| +| A | `StagedKernel` | SIMD inner kernel; writes a per-pair accumulator `Acc` into a `partial` buffer | +| B | `Postprocess` | maps `Acc → Score` (e.g. dequantize) using optional per-call metadata (`scratch_len` + `apply`) | +| C | `Reducer` | folds per-doc scores into the running MaxSim (max) | +| — | `StagedConvert` | optional input conversion at tile load (identity for f32; future: on-the-fly quantize) | + +The driver allocates **all** scratch (`partial`, `scored`, conversion buffers) +from a caller-supplied `ScopedAllocator`. Callers size nothing. + +## What varies per axis (the generality proof) + +| Axis | f32 | 4-bit MinMax (`i8`) | +|------|-----|---------------------| +| Stage A kernel | `StagedF32Kernel` ([`v3.rs`](v3.rs)) | `StagedI8Kernel` ([`i8.rs`](i8.rs)) | +| `Acc` type | `f32` | `i32` | +| Stage B postprocess | `Identity` — `scratch_len = 0`, returns acc ([`maxsim.rs`](maxsim.rs)) | `MinMaxPostprocess` — `a·x + b` dequant → `f32` ([`i8.rs`](i8.rs)) | +| Stage C reducer | `MaxReducer` ([`maxsim.rs`](maxsim.rs)) | `MaxReducer` *(shared, unchanged)* | +| Convert | identity | identity (codes are pre-quantized) | + +Only **Stage A + the `Acc` type + Stage B** change between f32 and quantized; the +driver, tiling, and reducer are reused verbatim. That is the thesis. + +## Evidence + +- **f32 = parity.** The staged f32 path is bit-for-bit equal to the hand-fused V3 + kernel (it *is* the same math, restructured) and within ±1.7% throughput. + - Tests: `staged_matches_fused_v3` ([`v3.rs`](v3.rs)), + `staged_f32_arena_reuse` ([`../../factory.rs`](../factory.rs)). + - Benches: `example/multi-vector-staged.json` (fused vs staged sweep), + `example/multi-vector-3way.json` (reference vs fused vs staged). +- **Quantized = new datatype, real win.** 4-bit MinMax over `i8` codes was added + with **no driver change**; it is correct and runs **1.5–4.1×** faster than the + per-pair SIMD reference across a dim sweep. + - Tests: `staged_i8_matches_minmax_reference`, + `staged_i8_arena_reuse_across_calls`, `staged_i8_multi_tile_tiny_budget` + ([`i8.rs`](i8.rs)). + - Bench: `example/multi-vector-quant.json` (reference vs staged). + - The reference (`MinMaxKernel`) is **not** scalar: its per-pair inner product + over 4-bit codes is itself SIMD. The staged win comes from + fusion / block-transposition / tiling, not from SIMD-vs-scalar. + +## Scratch & allocation + +Driver-owned scratch comes from a passed `ScopedAllocator`. For a +zero-allocation steady state, callers reuse a single-owner resettable bump arena, +[`ResettableArena`](arena.rs): + +- the f32 kernel reuses one via `RefCell` + (`PreparedStaged` in [`../../factory.rs`](../factory.rs)); +- the quantized POC owns one in `QuantStagedQuery` ([`i8.rs`](i8.rs)) and `reset`s + it per call. + +`ResettableArena` is deliberately **not** `Clone`/`Sync`: `reset(&mut self)` is +sound only because the borrow checker forbids resetting while any +`ScopedAllocator` still borrows it. (The shared `BumpAllocator` is grow-only and +`Sync`, so it has no `reset`.) + +## Scope + +- The f32 staged kernel is wired end-to-end and selectable for A/B benchmarking as + `MaxSimIsa::X86_64_V3_Staged` ([`../../isa.rs`](../isa.rs)); it coexists with the + fused `X86_64_V3` path. +- The quantized path is a standalone POC entry (`QuantStagedQuery` / + `QuantStagedDocs`), **not** yet behind a `MaxSimIsa` variant or a productized + storage `Repr`. +- Deferred: a quantized storage `Repr`, folding the quantized path into + `MaxSimIsa` / the factory, V4 (AVX-512) Stage-A kernels, and richer reducers + (argmax / top-k). diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/arena.rs b/diskann-quantization/src/multi_vector/distance/kernels/staged/arena.rs new file mode 100644 index 000000000..d7860661a --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/arena.rs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! POC-local single-threaded resettable bump arena for the staged quantized +//! kernel. +//! +//! Unlike [`BumpAllocator`](crate::alloc::BumpAllocator) — a `Clone`/`Send`/`Sync` +//! grow-only arena reclaimed only when its last clone drops — this is a +//! *single-owner, single-threaded* arena with O(1) +//! [`reset`](ResettableArena::reset). It is owned by one +//! [`QuantStagedQuery`](super::i8::QuantStagedQuery) and reused across +//! `compute_max_sim` calls, so the staged driver's per-call `partial` / `scored` +//! scratch allocates from it instead of the global heap (zero heap traffic in +//! steady state). +//! +//! `reset` is sound *because* the arena is not shareable: it takes `&mut self`, +//! and the borrow checker therefore guarantees no +//! [`ScopedAllocator`](crate::alloc::ScopedAllocator) borrowing this arena — hence +//! no outstanding allocation — is alive at the rewind point. `deallocate` is a +//! no-op; storage is reclaimed wholesale by `reset` or when the arena drops. + +use std::cell::{Cell, UnsafeCell}; +use std::ptr::NonNull; + +use crate::alloc::{AlignedAllocator, AllocatorCore, AllocatorError, Poly}; + +/// A single-owner, single-threaded resettable bump arena over an owned, +/// 64-byte-aligned byte buffer. Hands out aligned sub-slices by bumping a +/// non-atomic `head`; [`reset`](Self::reset) rewinds it. +pub(crate) struct ResettableArena { + /// Backing storage. `UnsafeCell` legitimizes handing out `*mut` ranges while + /// `allocate` holds only `&self` (mirrors `BumpAllocator`'s buffer). + buffer: Poly, AlignedAllocator>, + /// Bump cursor (non-atomic — this arena is single-threaded). + head: Cell, +} + +impl ResettableArena { + /// Allocate a fresh arena with room for `capacity` bytes, base-aligned to 64. + pub(crate) fn with_capacity(capacity: usize) -> Result { + let buffer = Poly::<[u8], _>::new_uninit_slice(capacity.max(1), AlignedAllocator::A64)?; + let (ptr, alloc) = Poly::into_raw(buffer); + + // SAFETY: `UnsafeCell<[u8]>` shares the layout of `[u8]`, and `MaybeUninit` + // is layout-compatible with `u8` (`u8` is valid for any bit pattern, and bytes + // are only read after the allocator hands them out and the caller writes them). + // `ptr` is non-null, having come from `Poly::into_raw`. + let buffer = unsafe { + Poly::from_raw( + NonNull::new_unchecked(ptr.as_ptr() as *mut UnsafeCell<[u8]>), + alloc, + ) + }; + + Ok(Self { + buffer, + head: Cell::new(0), + }) + } + + /// Rewind the arena, freeing every prior allocation in O(1). + /// + /// The `&mut self` receiver is load-bearing: it makes the borrow checker + /// forbid calling `reset` while any [`ScopedAllocator`](crate::alloc::ScopedAllocator) + /// borrowing this arena (and hence any live allocation) exists — which is what + /// makes rewinding the cursor sound. + pub(crate) fn reset(&mut self) { + self.head.set(0); + } + + /// Total capacity in bytes. + fn capacity(&self) -> usize { + self.buffer.get().len() + } + + /// Base pointer of the backing buffer. + fn base(&self) -> *mut u8 { + self.buffer.get().cast::() + } +} + +// SAFETY: on success `allocate` returns a slice of exactly `layout.size()` bytes +// whose base is aligned to at least `layout.align()` (the running offset is aligned +// up to `layout.align()` relative to the real base address); a request that cannot +// fit the fixed capacity returns an error. `deallocate` is a no-op — storage is +// reclaimed only by `reset` or on drop. +unsafe impl AllocatorCore for ResettableArena { + fn allocate(&self, layout: std::alloc::Layout) -> Result, AllocatorError> { + let base = self.base() as usize; + let head = self.head.get(); + // Align the current free address up to `layout.align()`, then reserve `size`. + let cur = base.checked_add(head).ok_or(AllocatorError)?; + let aligned = cur + .checked_next_multiple_of(layout.align()) + .ok_or(AllocatorError)?; + let pad = aligned - cur; + let new_head = head + .checked_add(pad) + .and_then(|h| h.checked_add(layout.size())) + .ok_or(AllocatorError)?; + if new_head > self.capacity() { + return Err(AllocatorError); + } + self.head.set(new_head); + + // SAFETY: `head + pad <= new_head <= capacity`, so the offset is in-bounds of + // the backing buffer and the range `[off, off + size)` lies within it. + let ptr = unsafe { self.base().add(head + pad) }; + NonNull::new(std::ptr::slice_from_raw_parts_mut(ptr, layout.size())).ok_or(AllocatorError) + } + + unsafe fn deallocate(&self, _ptr: NonNull<[u8]>, _layout: std::alloc::Layout) {} +} + +impl std::fmt::Debug for ResettableArena { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResettableArena") + .field("capacity", &self.capacity()) + .field("head", &self.head.get()) + .finish() + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/driver.rs b/diskann-quantization/src/multi_vector/distance/kernels/staged/driver.rs new file mode 100644 index 000000000..194b527a8 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/driver.rs @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Staged tiling driver. +//! +//! Structurally identical to [`super::super::tiled_reduce`] (A-tile → B-tile), but +//! the inner work per A-panel is split into three stages: Stage A fills +//! `partial_buf` for the whole B-tile, Stage B ([`Postprocess::apply`]) turns the +//! raw `Acc` block into `Score`s, and Stage C ([`Reducer::fold_block`]) folds them +//! into the running state. The driver runs all three uniformly; for the identity +//! postprocess (`Acc == Score`) Stage B is `#[inline(always)]` and returns the +//! `partial_buf` pointer unchanged, so it compiles away (no `scored_buf`, no pass). +//! +//! Every B-tile is `≤ b_tile_rows` rows: a full tile has `b_panels_per_tile` +//! complete panels, and a final short tile may end in a `< B_PANEL` remainder +//! panel — both handled by one loop (`full_panels` + an optional `tail`), so there +//! is no separate "peeled tail" code path. + +use diskann_wide::Architecture; + +use super::super::TileBudget; +use super::super::layouts::Layout; +use super::{FoldCtx, Postprocess, Reducer, StagedConvert, StagedKernel, StagedPlan}; +use crate::alloc::{Poly, ScopedAllocator}; + +/// Run the staged loop. `state` (len `a_padded_nrows`) is the persistent running +/// reduction across all B-tiles (the caller's output buffer). All transient +/// scratch — `partial` (Stage A output), the Stage-B `scored` region, and the +/// conversion buffers — is allocated **internally** from `alloc`; the caller +/// sizes nothing and hands in only the allocator. +/// +/// # Safety +/// +/// * `a_ptr` valid for `a_padded_nrows * k` `LA::Element`; `a_padded_nrows` +/// a multiple of `SK::A_PANEL`. +/// * `b_ptr` valid for `b_nrows * k` `LB::Element`. +#[allow(clippy::too_many_arguments, clippy::expect_used)] +pub(super) unsafe fn tiled_reduce_staged( + arch: A, + ca: &LA, + cb: &LB, + post: &P, + a_ptr: *const LA::Element, + a_padded_nrows: usize, + b_ptr: *const LB::Element, + b_nrows: usize, + k: usize, + state: &mut [R::State], + alloc: ScopedAllocator<'_>, + budget: TileBudget, +) where + A: Architecture, + SK: StagedKernel, + P: Postprocess, + R: Reducer, + LA: StagedConvert, + LB: StagedConvert, +{ + let a_panel = SK::A_PANEL; + let b_panel = SK::B_PANEL; + + // Initialize the running reduction state. + for s in state[..a_padded_nrows].iter_mut() { + *s = R::init(); + } + + // Zero-dimensional vectors: every IP is 0. The caller fills the score for + // this degenerate case; here we just avoid the zero-stride tiling nest. + if k == 0 { + return; + } + + debug_assert_eq!( + a_padded_nrows % a_panel, + 0, + "a_padded_nrows must be a multiple of A_PANEL" + ); + + let acc_bytes = core::mem::size_of::(); + let a_row_bytes = k * core::mem::size_of::<::Element>(); + let b_row_bytes = k * core::mem::size_of::<::Element>(); + let plan = StagedPlan::new( + a_row_bytes, + b_row_bytes, + a_panel, + b_panel, + acc_bytes, + budget, + ); + + let a_tile_rows = a_panel * plan.a_panels_per_tile; + let b_tile_rows = b_panel * plan.b_panels_per_tile; + + let a_kern_panel_stride = a_panel * k; + let b_kern_panel_stride = b_panel * k; + + // Conversion staging buffers, also from the caller's allocator — 0-length + // (a no-op dangling allocation) for the identity conversions every current + // staged kernel uses. Sized by the staged-local `StagedConvert` contract, so + // the staged driver never touches the shared `ConvertTo` machinery. + let a_conv_len = ca.scratch_len(a_tile_rows.min(a_padded_nrows), k); + let mut a_conv = + Poly::<[::Element], _>::new_uninit_slice(a_conv_len, alloc) + .expect("a-side conversion scratch allocation"); + let b_conv_len = cb.scratch_len(b_tile_rows.min(b_nrows), k); + let mut b_conv = + Poly::<[::Element], _>::new_uninit_slice(b_conv_len, alloc) + .expect("b-side conversion scratch allocation"); + let a_conv_ptr = a_conv.as_mut_ptr().cast::<::Element>(); + let b_conv_ptr = b_conv.as_mut_ptr().cast::<::Element>(); + + // Internal scratch, allocated from the caller's allocator — the caller sizes + // nothing. `partial` is Stage A's output (the kernel declares its size via + // `StagedKernel::partial_len`); `scored` is Stage B's output, sized by the + // postprocess contract (a 0-length, no-op dangling allocation for the identity + // postprocess). Every B-tile is `≤ b_tile_rows` wide, so `b_tile_rows` is the + // exact upper bound on `valid_b_cols`. Both `Poly`s live to the end of the + // call, then free via `alloc` (a global free, or a no-op for a bump allocator). + let partial_len = SK::partial_len(k, budget); + let mut partial = Poly::<[SK::Acc], _>::new_uninit_slice(partial_len, alloc) + .expect("partial scratch allocation"); + let scored_len = post.scratch_len(a_panel, b_tile_rows); + let mut scored = Poly::<[P::Score], _>::new_uninit_slice(scored_len, alloc) + .expect("scored scratch allocation"); + + let partial_ptr = partial.as_mut_ptr().cast::(); + let scored_ptr = scored.as_mut_ptr().cast::(); + let state_ptr = state.as_mut_ptr(); + + // SAFETY: all pointer arithmetic stays within the respective allocations; + // this mirrors `super::super::tiled_reduce`'s established bounds. + unsafe { + let mut rows_done: usize = 0; + + // Loop 1: A tiles. + while rows_done < a_padded_nrows { + let tile_rows = a_tile_rows.min(a_padded_nrows - rows_done); + let pa_tile_src = a_ptr.add(rows_done * k); + let pr_tile = state_ptr.add(rows_done); + + let pa_tile = ca.convert(a_conv_ptr, arch, pa_tile_src, tile_rows, k); + let pa_tile_end = pa_tile.add(tile_rows * k); + + // Loop 2: B tiles. Each is `bt_rows = min(b_tile_rows, remaining)` — + // full tiles end on a panel boundary (`tail == 0`); the final short + // tile may carry a `< B_PANEL` remainder panel. + let mut pb_tile_src = b_ptr; + let mut b_row_offset = 0usize; + while b_row_offset < b_nrows { + let bt_rows = b_tile_rows.min(b_nrows - b_row_offset); + let pb_tile = cb.convert(b_conv_ptr, arch, pb_tile_src, bt_rows, k); + let full_panels = bt_rows / b_panel; + let tail = bt_rows % b_panel; + + // Loop 3: A micro-panels. + // + // Partial-buffer granularity: one A-panel (P_a = 1) against the + // whole B-tile (P_b = b_panels_per_tile). Loop 4 below runs ONLY + // Stage A, so the kernel stays hot in i-cache for the entire + // B-tile; Stage C then folds the whole block in one pass. See + // `StagedPlan` and docs/staged_multi_vector_kernel.md §5. + let mut pa_panel = pa_tile; + let mut pr_panel = pr_tile; + let mut a_row_offset = rows_done; + while pa_panel < pa_tile_end { + // Stage A: fill partial_buf for this A-panel across the B-tile + // (the full panels, then a `< B_PANEL` remainder panel if any). + let mut pb_panel = pb_tile; + let mut col = 0usize; + for _ in 0..full_panels { + SK::full_panel( + arch, + pa_panel, + pb_panel, + k, + partial_ptr.add(col * a_panel), + a_panel, + ); + pb_panel = pb_panel.add(b_kern_panel_stride); + col += b_panel; + } + if tail > 0 { + SK::partial_panel( + arch, + tail, + pa_panel, + pb_panel, + k, + partial_ptr.add(col * a_panel), + a_panel, + ); + } + + // Stage B (Acc -> Score): identity returns `partial_ptr` for + // free; a quantized post fills `scored` (the driver-allocated + // region), using the global (a_row_offset, b_row_offset) to + // index its metadata. Output column stride is contractually + // `a_panel`. + let scores = post.apply( + scored_ptr, + arch, + partial_ptr, + FoldCtx { + a_panel, + valid_b_cols: bt_rows, + b_stride: a_panel, + a_row_offset, + b_row_offset, + }, + ); + // Stage C: fold the scores into the running state (one fold + // per A-panel × B-tile — the widest, cheapest fold). + R::fold_block(arch, pr_panel, scores, a_panel, bt_rows, a_panel); + + pa_panel = pa_panel.add(a_kern_panel_stride); + pr_panel = pr_panel.add(a_panel); + a_row_offset += a_panel; + } + + pb_tile_src = pb_tile_src.add(bt_rows * k); + b_row_offset += bt_rows; + } + + rows_done += tile_rows; + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/i8.rs b/diskann-quantization/src/multi_vector/distance/kernels/staged/i8.rs new file mode 100644 index 000000000..a51c5a6a6 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/i8.rs @@ -0,0 +1,962 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! V3 (AVX2) **quantized** staged path: 4-bit MinMax MaxSim. +//! +//! This is the payoff the staged framework was built for — the first +//! **non-identity Stage B**. Stage A accumulates a raw *integer* dot product +//! (`Acc = i32`), Stage B ([`MinMaxPostprocess`]) turns each `i32` into the +//! finished MinMax inner product (`Score = f32`) using per-vector +//! scale/center/sum metadata, and Stage C ([`MaxReducer`](super::maxsim::MaxReducer)) +//! folds the `f32` scores exactly as in the f32 path — **unchanged**. +//! +//! # Stage A mirrors the f32 kernel (block-transposed + broadcast, no reduction) +//! +//! The integer micro-kernel is structurally the f32 [`store_microkernel`] with +//! `PACK = 2` integer MACs: +//! +//! * **Query** (`Left`): [`BlockTransposed`] — codes widened `u8→i16` +//! once at build time, then block-transposed with two K-columns interleaved per +//! row (`[r0_k0, r0_k1, r1_k0, r1_k1, …]`). One col-pair is `GROUP·PACK = 32` +//! `i16` = two `i16x16` halves (rows `0..8` / `8..16`). +//! * **Doc** (`Right`): [`RowMajor`] — codes stream in place; the kernel reads +//! each doc col's 2-K word, widens it, and broadcasts it as +//! `[d_k0, d_k1, d_k0, d_k1, …]`. +//! * **MAC**: `i32x8::dot_simd(query_half, doc_broadcast)` — i.e. `vpmaddwd` +//! (`_mm256_madd_epi16`), which sums each interleaved K-pair into one `i32` +//! lane. Each lane **is** one A-row's running dot for that doc col, so the +//! accumulators *are* the outputs — no per-pair horizontal reduction, exactly +//! the property block-transposition buys in the f32 kernel. +//! +//! 4-bit codes are `u8 ∈ [0, 15]` (not nibble-packed), widened to `i16 ∈ [0, 15]`. +//! A 2-K partial is `≤ 2·15·15 = 450` and the full dot over `dim ≤ 512` is +//! `≤ 15·15·512 ≈ 1.2e5`, so the `i32` accumulation **never overflows** — the +//! integer dot is *exact*. No new `diskann-wide` op is needed: the existing +//! [`SIMDDotProduct`] for `i32x8` is the right shape in broadcast config. +//! +//! # Even-K contract (zero driver change) +//! +//! The shared [driver](super::driver) derives the A-side physical row stride from +//! its `k` argument (`rows_done · k == block_offset`), which only matches a +//! block-transposed query when `k == padded_ncols`. So the entry point passes the +//! query's **padded** (even) column count as the driver `k` and requires the doc +//! to match (zero-padded to even). The padding column holds `0` on *both* sides, +//! so it contributes `0` to every dot — the IP is unchanged, and the kernel walks +//! exactly `k/2` full K-pairs with no odd-K tail branch. + +use std::num::NonZeroUsize; + +use diskann_utils::ReborrowMut; +use diskann_wide::arch::x86_64::V3; +use diskann_wide::{SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDVector}; + +use super::super::TileBudget; +use super::super::layouts; +use super::arena::ResettableArena; +use super::driver::tiled_reduce_staged; +use super::maxsim::MaxReducer; +use super::{FoldCtx, Postprocess, StagedKernel}; +use crate::CompressInto; +use crate::algorithms::Transform; +use crate::algorithms::transforms::NullTransform; +use crate::alloc::ScopedAllocator; +use crate::minmax::{MinMaxCompensation, MinMaxMeta, MinMaxQuantizer}; +use crate::multi_vector::{BlockTransposed, BlockTransposedRef, Defaulted, Mat, MatRef, Standard}; +use crate::num::Positive; + +diskann_wide::alias!(i16s = ::i16x16); +diskann_wide::alias!(i32s = ::i32x8); +diskann_wide::alias!(u32s = ::u32x8); +diskann_wide::alias!(f32s = ::f32x8); + +// ── Stage A: integer store-out micro-kernel ────────────────────── + +/// Zero-sized Stage-A kernel marker for the quantized (4-bit MinMax) staged path +/// with block size `GROUP`. +pub(crate) struct StagedI8Kernel; + +// SAFETY: `full_panel`/`partial_panel` read A_PANEL(16) i16 query rows × K +// (block-transposed, K padded to even) and UNROLL × K u8 doc elements, and write +// UNROLL columns of A_PANEL(16) i32 into `partial` at stride `partial_b_stride` — +// all within the bounds the `StagedKernel` contract guarantees. +unsafe impl StagedKernel for StagedI8Kernel<16> { + type Left = layouts::BlockTransposed; + type Right = layouts::RowMajor; + type Acc = i32; + const A_PANEL: usize = 16; + const B_PANEL: usize = 4; + + #[inline(always)] + unsafe fn full_panel( + arch: V3, + a: *const i16, + b: *const u8, + k: usize, + partial: *mut i32, + partial_b_stride: usize, + ) { + // SAFETY: pointer validity per the `StagedKernel` contract. + unsafe { + int_store_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, partial, partial_b_stride) + } + } + + #[inline(always)] + unsafe fn partial_panel( + arch: V3, + remainder: usize, + a: *const i16, + b: *const u8, + k: usize, + partial: *mut i32, + partial_b_stride: usize, + ) { + // SAFETY: pointer validity per the `StagedKernel` contract. + unsafe { + match remainder { + 1 => int_store_microkernel::<1>(arch, a, b, k, partial, partial_b_stride), + 2 => int_store_microkernel::<2>(arch, a, b, k, partial, partial_b_stride), + 3 => int_store_microkernel::<3>(arch, a, b, k, partial, partial_b_stride), + _ => unreachable!( + "unexpected remainder {remainder} for B_PANEL={}", + Self::B_PANEL + ), + } + } + } +} + +/// V3 integer store-out micro-kernel: 16 A-rows × `UNROLL` B-rows. +/// +/// Mirrors [`super::v3::store_microkernel`] with `PACK = 2` integer MACs (see the +/// module docs). The epilogue stores each B-column's 16 A-row `i32` accumulators +/// into `partial` (A-major: column `j` at `partial + j*b_stride`, as two `i32x8` +/// halves) — identical contract to the f32 kernel, so Stage B / Stage C are +/// unchanged. +/// +/// # Safety +/// +/// 1. `a_packed` points to a block-transposed query block of `16 * k` `i16` +/// (`k` even — the padded column count). +/// 2. `b` points to `UNROLL` rows of `k` contiguous `u8` (`k` even). +/// 3. `partial` is valid for `UNROLL` columns of 16 `i32` at stride `b_stride`. +#[inline(always)] +unsafe fn int_store_microkernel( + arch: V3, + a_packed: *const i16, + b: *const u8, + k: usize, + partial: *mut i32, + b_stride: usize, +) { + let mut p0 = [i32s::default(arch); UNROLL]; + let mut p1 = [i32s::default(arch); UNROLL]; + let offsets: [usize; UNROLL] = core::array::from_fn(|j| k * j); + + // One col-pair of the block-transposed query = GROUP·PACK = 32 i16 = two + // `i16x16` halves (rows 0..8 low, rows 8..16 high). + let a_pair_stride = 2 * i16s::LANES; + let a_half = i16s::LANES; + let pairs = k / 2; // `k` is even (the padded column count) ⇒ exact, no tail. + + for p in 0..pairs { + // SAFETY: precondition 1 — the query block has `pairs` col-pairs of 32 i16. + let (a0, a1) = unsafe { + ( + i16s::load_simd(arch, a_packed.add(a_pair_stride * p)), + i16s::load_simd(arch, a_packed.add(a_pair_stride * p + a_half)), + ) + }; + + for j in 0..UNROLL { + // SAFETY: precondition 2 — doc col j is `offsets[j]` in, and + // `2*p + 1 < k` because `pairs == k/2`. + let (d0, d1) = unsafe { + let base = 2 * p + offsets[j]; + ( + u32::from(b.add(base).read()), + u32::from(b.add(base + 1).read()), + ) + }; + // Broadcast the K-pair [d0, d1] across all 8 i32 lanes, reinterpreted + // as i16x16 = [d0, d1, d0, d1, …] — the shape `madd_epi16` consumes + // (pairing each query [k0, k1] with [d0, d1] into one i32 lane). + let packed = d0 | (d1 << 16); + let bcast: i16s = u32s::splat(arch, packed).reinterpret_simd(); + p0[j] = p0[j].dot_simd(a0, bcast); + p1[j] = p1[j].dot_simd(a1, bcast); + } + } + + for j in 0..UNROLL { + // SAFETY: precondition 3 — column j occupies [j*b_stride, j*b_stride+16) i32. + unsafe { + p0[j].store_simd(partial.add(j * b_stride)); + p1[j].store_simd(partial.add(j * b_stride + i32s::LANES)); + } + } +} + +// ── Stage B: integer code → MinMax inner product ───────────────── + +/// Stage B for 4-bit MinMax: convert each raw integer dot `⟨codes⟩` (the `i32` +/// `Acc` from Stage A) into the finished MinMax inner product +/// +/// ```text +/// IP = qm.a·dm.a·⟨codes⟩ + qm.n·dm.b + dm.n·qm.b + qm.b·dm.b·dim +/// ``` +/// +/// (the linear decomposition in `minmax::vectors`), emitting **+IP** so Stage C +/// folds `max` and the caller negates once at the end (`min distance = +/// max_a(-IP) = -max_a IP`). +/// +/// This is the first non-identity [`Postprocess`]: it reports a non-zero +/// [`scratch_len`](Postprocess::scratch_len) (so the driver allocates an f32 +/// region) and [`apply`] writes the converted scores into it, indexing its +/// per-vector metadata by the *global* row offsets in [`FoldCtx`]. +pub(crate) struct MinMaxPostprocess<'m> { + /// Per-query-vector metadata, indexed by `ctx.a_row_offset + i`. Length must + /// be `≥ padded_nrows` (padded rows carry default metadata; their scores are + /// computed but never read). + query_meta: &'m [MinMaxCompensation], + /// Per-doc-vector metadata, indexed by `ctx.b_row_offset + c`. Length `≥ nd`. + doc_meta: &'m [MinMaxCompensation], + /// The **logical** dimension (the `dim` term in the IP formula). The integer + /// dot is taken over the padded columns, but the extra column's codes are `0` + /// on both sides, so `⟨codes⟩` is unchanged. + dim: f32, +} + +impl<'m> MinMaxPostprocess<'m> { + pub(crate) fn new( + query_meta: &'m [MinMaxCompensation], + doc_meta: &'m [MinMaxCompensation], + dim: usize, + ) -> Self { + Self { + query_meta, + doc_meta, + dim: dim as f32, + } + } +} + +// SAFETY: `apply` reads exactly `ctx.valid_b_cols` columns of `ctx.a_panel` `i32` +// from `acc` at stride `ctx.b_stride`, writes only the corresponding +// `ctx.a_panel × ctx.valid_b_cols` region of `scratch` (the driver allocates +// `scratch_len(a_panel, max_b_cols) = a_panel · max_b_cols ≥ a_panel · +// valid_b_cols` `f32`), and returns a pointer into it. Metadata indices +// `a_row_offset + i < padded_nrows ≤ query_meta.len()` and `b_row_offset + c < nd +// ≤ doc_meta.len()` are in bounds by the entry-point's preconditions. +// +// V3-specific: the quantized kernel only runs on V3 (`StagedI8Kernel: StagedKernel`), +// so Stage B is implemented for V3 only and uses AVX2 for the score conversion. +unsafe impl Postprocess for MinMaxPostprocess<'_> { + type Acc = i32; + type Score = f32; + + #[inline] + fn scratch_len(&self, a_panel: usize, max_b_cols: usize) -> usize { + a_panel * max_b_cols + } + + #[inline] + unsafe fn apply( + &self, + scratch: *mut f32, + arch: V3, + acc: *const i32, + ctx: FoldCtx, + ) -> *const f32 { + let out = scratch; + + // Rewrite the per-(row i, col c) IP into a per-row-vector form whose only + // per-column inputs are three doc scalars (so the 16-row inner loop is a + // straight SIMD sweep): + // ip = dm.a·(qm.a·raw) + dm.b·qm.n + (dm.n + dm.b·dim)·qm.b + // = A_c·(qa·raw) + B_c·qn + C_c·qb, + // with A_c=dm.a, B_c=dm.b, C_c=dm.n+dm.b·dim, and qa/qn/qb the per-query-row + // metadata. This is the reference formula regrouped (within f32 rounding). + // The quantized kernel always uses A_PANEL = 16 = 2*LANES, so Stage B is a + // pure SIMD sweep — no scalar fallback for other panel widths. + let lanes = f32s::LANES; + debug_assert_eq!( + ctx.a_panel, + 2 * lanes, + "quantized Stage B expects A_PANEL == 2*LANES" + ); + + // Gather the per-A-row metadata into contiguous SoA arrays once (the + // source `MinMaxCompensation` is AoS, so a strided scalar gather), then + // hold them in registers across the whole B-column sweep. + let mut qa = [0.0f32; 16]; + let mut qb = [0.0f32; 16]; + let mut qn = [0.0f32; 16]; + for i in 0..16 { + let qm = self.query_meta[ctx.a_row_offset + i]; + qa[i] = qm.a; + qb[i] = qm.b; + qn[i] = qm.n; + } + // SAFETY: each array holds exactly 16 = 2·LANES f32. + let (qa0, qa1, qb0, qb1, qn0, qn1) = unsafe { + ( + f32s::load_simd(arch, qa.as_ptr()), + f32s::load_simd(arch, qa.as_ptr().add(lanes)), + f32s::load_simd(arch, qb.as_ptr()), + f32s::load_simd(arch, qb.as_ptr().add(lanes)), + f32s::load_simd(arch, qn.as_ptr()), + f32s::load_simd(arch, qn.as_ptr().add(lanes)), + ) + }; + + for c in 0..ctx.valid_b_cols { + let dm = self.doc_meta[ctx.b_row_offset + c]; + let a_c = f32s::splat(arch, dm.a); + let b_c = f32s::splat(arch, dm.b); + let c_c = f32s::splat(arch, dm.n + dm.b * self.dim); + let acc_col = c * ctx.b_stride; + let out_col = c * ctx.a_panel; + // SAFETY: `acc_col + 2·LANES ≤ valid_b_cols · b_stride`; the partial + // block is valid for that many i32, and `out_col + 2·LANES ≤ buf.len()`. + unsafe { + let raw0 = i32s::load_simd(arch, acc.add(acc_col)).simd_cast(); + let raw1 = i32s::load_simd(arch, acc.add(acc_col + lanes)).simd_cast(); + // a_c·(qa·raw) + (b_c·qn + c_c·qb) + let s0 = a_c.mul_add_simd(qa0 * raw0, b_c.mul_add_simd(qn0, c_c * qb0)); + let s1 = a_c.mul_add_simd(qa1 * raw1, b_c.mul_add_simd(qn1, c_c * qb1)); + s0.store_simd(out.add(out_col)); + s1.store_simd(out.add(out_col + lanes)); + } + } + scratch.cast_const() + } +} + +// ── Public POC entry: prepared 4-bit MinMax staged MaxSim ──────── + +/// Quantize an f32 multi-vector to 4-bit MinMax (Null transform, scale 1.0) — +/// the shared quantizer for the public query/doc builders so both sides decode +/// to comparable codes + metadata. +#[allow(clippy::expect_used)] // POC constructor: inputs are pre-validated by the caller. +fn quantize_minmax_4bit(input: MatRef<'_, Standard>) -> Mat> { + let dim = input.vector_dim(); + let n = input.num_vectors(); + let q = MinMaxQuantizer::new( + Transform::Null(NullTransform::new( + NonZeroUsize::new(dim).expect("dimension must be non-zero"), + )), + Positive::new(1.0).expect("1.0 is positive"), + ); + let mut out: Mat> = + Mat::new(MinMaxMeta::new(n, dim), Defaulted).expect("MinMaxMeta allocation"); + q.compress_into(input, out.reborrow_mut()) + .expect("input must be finite (no NaN)"); + out +} + +/// A prepared 4-bit MinMax **query** set for the staged MaxSim kernel (V3/AVX2). +/// +/// Built once from an f32 multi-vector; [`compute_max_sim`](Self::compute_max_sim) +/// is the per-document-set hot path. It owns a `ResettableArena` that backs the +/// staged driver's per-call `partial` / Stage-B scratch and a reused `state` +/// buffer, so **steady-state calls perform no heap allocation** (the arena is +/// reset, not reallocated, each call). +/// +/// This is a **standalone POC entry**: the quantized path is intentionally *not* +/// yet unified into [`MaxSimIsa`](crate::multi_vector::distance::MaxSimIsa) / +/// `build_max_sim`, which ties to a productized `QuantizedSoa` matrix `Repr` (see +/// [`QuantStagedDocs`] and `docs/staged_multi_vector_kernel.md`). +pub struct QuantStagedQuery { + /// Codes widened `u8→i16` and block-transposed (`PACK=2`) for Stage A. + query: BlockTransposed, + /// Per-vector metadata, padded to `query.padded_nrows()` (padded rows carry + /// default metadata; their scores are computed but never read). + meta: Vec, + /// Logical dimension. + dim: usize, + arch: V3, + /// Reusable running-reduction output (len `query.padded_nrows()`); the driver + /// re-initialises it each call, so it is reused, not reallocated. + state: Vec, + /// Reusable arena backing the driver's transient `partial` / `scored` scratch; + /// reset (not reallocated) at the top of every call. + arena: ResettableArena, +} + +impl QuantStagedQuery { + /// Quantize `query` to 4-bit MinMax and prepare the block-transposed layout + + /// the reusable scratch arena. Returns `None` if AVX2 (V3) is unavailable on + /// this host. + #[allow(clippy::expect_used)] // POC constructor: dims are valid by construction. + pub fn build(query: MatRef<'_, Standard>) -> Option { + let arch = V3::new_checked()?; + let dim = query.vector_dim(); + let nq = query.num_vectors(); + + let q_mat = quantize_minmax_4bit(query); + + let mut codes = vec![0i16; nq * dim]; + for r in 0..nq { + let row = q_mat.get_row(r).expect("row r < nq"); + for j in 0..dim { + codes[r * dim + j] = i16::from(row.vector().get(j).expect("col j < dim") as u8); + } + } + let code_view = MatRef::new(Standard::::new(nq, dim).expect("nq×dim i16"), &codes) + .expect("code slice length"); + let bt = BlockTransposed::::from_matrix_view(code_view.as_matrix_view()); + + let padded_nrows = bt.padded_nrows(); + let mut meta = vec![MinMaxCompensation::default(); padded_nrows]; + for (r, m) in meta.iter_mut().enumerate().take(nq) { + *m = q_mat.get_row(r).expect("row r < nq").meta(); + } + + // The driver allocates only `partial` + `scored` from the arena (the + // identity conversion buffers are zero-length). `StagedPlan` co-budgets so + // each is `<= l1_b`, so `2 * l1_b` is a provable upper bound for *any* k + // (no per-shape sizing needed); add one page of headroom for alignment. + let arena_bytes = 2 * TileBudget::default().l1_b + 4096; + let arena = ResettableArena::with_capacity(arena_bytes).expect("staged arena allocation"); + + Some(Self { + query: bt, + meta, + dim, + arch, + state: vec![f32::MIN; padded_nrows], + arena, + }) + } + + /// Whether this host supports the quantized staged kernel (requires AVX2 / + /// the `V3` ISA). Use this to gate before calling [`build`](Self::build). + pub fn is_supported() -> bool { + V3::new_checked().is_some() + } + + /// Number of (logical) query vectors. + pub fn num_vectors(&self) -> usize { + self.query.nrows() + } + + /// Compute the per-query **min distance** (`= -max_d IP`) against `docs`, + /// writing one score per query vector into `scores`. + /// + /// # Panics + /// + /// Panics if `scores.len() != self.num_vectors()` or the query and doc + /// logical dimensions differ. + #[allow(clippy::expect_used)] // doc view length is guaranteed by construction. + pub fn compute_max_sim(&mut self, docs: &QuantStagedDocs, scores: &mut [f32]) { + let nq = self.query.nrows(); + assert_eq!( + scores.len(), + nq, + "scores length {} must equal query vector count {nq}", + scores.len() + ); + assert_eq!( + self.dim, docs.dim, + "query dim {} != doc dim {}", + self.dim, docs.dim + ); + + let doc = MatRef::new( + Standard::::new(docs.nv, docs.padded_dim).expect("nv×padded_dim u8"), + &docs.codes, + ) + .expect("doc code slice length"); + let post = MinMaxPostprocess::new(&self.meta, &docs.meta, self.dim); + + // Rewind the arena so the driver's `partial` / `scored` reuse last call's + // storage (`&mut self` proves no prior allocation is still borrowing it), + // then hand it the arena and the reused `state` output. Steady-state: no + // heap allocation. + let padded = self.query.padded_nrows(); + self.arena.reset(); + max_ip_kernel_staged_i8( + self.arch, + self.query.as_view(), + doc, + &post, + &mut self.state[..padded], + ScopedAllocator::new(&self.arena), + TileBudget::default(), + ); + + for (s, &raw) in scores.iter_mut().zip(self.state.iter()) { + *s = -raw; // min distance = -(max inner product) + } + } +} + +/// A prepared 4-bit MinMax **document** set: the minimal "codes-together, +/// metadata-together" SoA the staged kernel streams. Stage A reads the contiguous +/// codes region (row-major `u8`, `padded_dim` per vector); Stage B reads the +/// contiguous metadata region (one [`MinMaxCompensation`] per vector). +/// +/// This is the doc-side storage the kernel was designed around — the interleaved +/// `MinMaxMeta` `Repr` (one codes+meta blob per row) *cannot* be streamed by a +/// kernel that needs the codes contiguous for SIMD and the metadata only in the +/// postprocess. +/// +/// # Productization (assessed, not built): a `QuantizedSoa` matrix `Repr` +/// +/// The natural next step is a matrix `Repr` that owns this layout in one +/// allocation, `[codes region | aligned metadata region]`, with `Row<'a> = (&'a +/// [u8], &'a MinMaxCompensation)`. It fits the existing `Repr`/`ReprOwned` +/// contract (`layout()` = `codes_bytes + pad + meta_bytes`; `get_row` splits the +/// two regions) and needs a `CompressInto` that emits SoA from `MinMaxQuantizer` +/// (today's emits the interleaved blob). That `Repr` is also the prerequisite for +/// unifying the quantized path into `MaxSimIsa`; this owning prototype keeps the +/// POC self-contained without committing to it. +pub struct QuantStagedDocs { + /// Row-major codes, `nv * padded_dim` `u8` (each `∈ [0, 15]` for 4-bit), with + /// the trailing (padded) column zeroed when `dim` is odd. + codes: Vec, + /// Per-vector metadata, `nv` entries. + meta: Vec, + /// Logical dimension (the IP formula's `dim`). + dim: usize, + /// Physical (even) column count `next_multiple_of(dim, 2)` — the doc row + /// stride, which must equal the query's padded column count. + padded_dim: usize, + /// Number of vectors. + nv: usize, +} + +impl QuantStagedDocs { + /// Quantize `docs` to 4-bit MinMax and pack the codes-together / + /// metadata-together SoA (codes zero-padded to an even `padded_dim`). + #[allow(clippy::expect_used)] // POC constructor: dims are valid by construction. + pub fn build(docs: MatRef<'_, Standard>) -> Self { + let dim = docs.vector_dim(); + let nv = docs.num_vectors(); + let padded_dim = dim.next_multiple_of(2); + + let d_mat = quantize_minmax_4bit(docs); + let mut codes = vec![0u8; nv * padded_dim]; + let mut meta = Vec::with_capacity(nv); + for r in 0..nv { + let row = d_mat.get_row(r).expect("row r < nv"); + for j in 0..dim { + codes[r * padded_dim + j] = row.vector().get(j).expect("col j < dim") as u8; + } + meta.push(row.meta()); + } + Self { + codes, + meta, + dim, + padded_dim, + nv, + } + } + + /// Number of document vectors. + pub fn num_vectors(&self) -> usize { + self.nv + } +} + +// ── Entry point ────────────────────────────────────────────────── + +/// Compute per-query-vector max MinMax inner product into `state` via the staged +/// quantized pipeline. `state` (len ≥ `query.padded_nrows()`) is the caller's +/// output, left holding the raw max-IP (the caller negates for min-distance). +/// Transient scratch (`partial`, Stage-B region) is allocated internally from +/// `alloc` — the caller sizes nothing. +/// +/// `query` is the block-transposed (`i16`, `GROUP=16`, `PACK=2`) widened codes; +/// `doc` is the row-major `u8` codes at stride `query.padded_ncols()` (even); and +/// `post` carries the query/doc metadata + logical `dim`. +/// +/// # Panics +/// +/// Panics if `state.len() < query.padded_nrows()` or `query.padded_ncols() != +/// doc.vector_dim()` (the even-K contract — see the module docs). +pub(crate) fn max_ip_kernel_staged_i8( + arch: V3, + query: BlockTransposedRef<'_, i16, 16, 2>, + doc: MatRef<'_, Standard>, + post: &MinMaxPostprocess<'_>, + state: &mut [f32], + alloc: ScopedAllocator<'_>, + budget: TileBudget, +) { + let padded = query.padded_nrows(); + // `k` is the *padded* (even) column count: the driver derives the A-side + // physical row stride from it, and it must match the doc stride. + let k = query.padded_ncols(); + if state.len() < padded || k != doc.vector_dim() { + max_ip_kernel_staged_i8_panic(state.len(), padded, k, doc.vector_dim()); + } + + let b_nrows = doc.num_vectors(); + + // Empty contraction: every IP reduces to the metadata-only terms. The POC + // does not exercise `dim == 0`; fill 0 and bail rather than enter the tiling + // nest with a zero stride (matches the f32 entry's degenerate guard). + if k == 0 { + state[..padded].fill(0.0); + return; + } + + let ca = layouts::BlockTransposed::::new(); + let cb = layouts::RowMajor::::new(); + + // SAFETY: + // - `query.as_ptr()` is valid for `padded * k` i16 (block-transposed, K padded + // to even == k), and `padded` is a multiple of GROUP == A_PANEL == 16. + // - `doc.as_slice()` is `b_nrows * k` contiguous u8 (k == doc.vector_dim()). + // - `state.len() >= padded` (checked); the driver allocates its scratch from + // `alloc`. + unsafe { + tiled_reduce_staged::, MinMaxPostprocess<'_>, MaxReducer, _, _>( + arch, + &ca, + &cb, + post, + query.as_ptr(), + padded, + doc.as_slice().as_ptr(), + b_nrows, + k, + &mut state[..padded], + alloc, + budget, + ); + } +} + +#[inline(never)] +#[cold] +#[allow(clippy::panic)] +fn max_ip_kernel_staged_i8_panic(state_len: usize, padded: usize, k: usize, doc_dim: usize) { + panic!( + "max_ip_kernel_staged_i8: precondition failed: \ + state.len()={state_len} (expected >= {padded}), \ + padded_ncols(k)={k}, doc.vector_dim()={doc_dim} (must be equal — even-K contract)" + ); +} + +#[cfg(test)] +mod tests { + use std::num::NonZeroUsize; + + use diskann_utils::ReborrowMut; + use diskann_wide::arch::x86_64::V3; + + use super::super::super::TileBudget; + use super::{ + MinMaxPostprocess, QuantStagedDocs, QuantStagedQuery, int_store_microkernel, + max_ip_kernel_staged_i8, + }; + use crate::CompressInto; + use crate::algorithms::Transform; + use crate::algorithms::transforms::NullTransform; + use crate::alloc::ScopedAllocator; + use crate::minmax::{MinMaxCompensation, MinMaxMeta, MinMaxQuantizer}; + use crate::multi_vector::distance::{MaxSim, QueryMatRef}; + use crate::multi_vector::{BlockTransposed, Defaulted, Mat, MatRef, Standard}; + use crate::num::Positive; + use diskann_vector::DistanceFunctionMut; + + const NBITS: usize = 4; + + fn quantizer(dim: usize) -> MinMaxQuantizer { + MinMaxQuantizer::new( + Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())), + Positive::new(1.0).unwrap(), + ) + } + + /// Pseudo-random f32 in roughly `[-1, 1]`, deterministic per `(seed, idx)`. + fn rnd(seed: u64, idx: usize) -> f32 { + let x = seed + .wrapping_mul(6364136223846793005) + .wrapping_add(idx as u64) + .wrapping_mul(1442695040888963407); + ((x >> 33) as f32 / (1u64 << 31) as f32) - 1.0 + } + + fn quantize(q: &MinMaxQuantizer, data: &[f32], n: usize, dim: usize) -> Mat> { + let input = MatRef::new(Standard::::new(n, dim).unwrap(), data).unwrap(); + let mut out: Mat> = Mat::new(MinMaxMeta::new(n, dim), Defaulted).unwrap(); + q.compress_into(input, out.reborrow_mut()).unwrap(); + out + } + + /// Extract `(codes_u8 [nv × padded_dim, zero-padded], meta [nv])` from a + /// quantized doc matrix — the minimal doc-side SoA. + fn doc_soa(mat: &Mat>, dim: usize) -> (Vec, Vec) { + let nv = mat.num_vectors(); + let padded_dim = dim.next_multiple_of(2); + let mut codes = vec![0u8; nv * padded_dim]; + let mut meta = Vec::with_capacity(nv); + for r in 0..nv { + let row = mat.get_row(r).unwrap(); + for j in 0..dim { + codes[r * padded_dim + j] = row.vector().get(j).unwrap() as u8; + } + meta.push(row.meta()); + } + (codes, meta) + } + + /// Extract `(codes_i16 [nq × dim], meta padded to padded_nrows)` for the query + /// side — i16 widening + row padding feeds `BlockTransposed`. + fn query_arrays( + mat: &Mat>, + dim: usize, + padded_nrows: usize, + ) -> (Vec, Vec) { + let nq = mat.num_vectors(); + let mut codes = vec![0i16; nq * dim]; + let mut meta = vec![MinMaxCompensation::default(); padded_nrows]; + for r in 0..nq { + let row = mat.get_row(r).unwrap(); + for j in 0..dim { + codes[r * dim + j] = i16::from(row.vector().get(j).unwrap() as u8); + } + meta[r] = row.meta(); + } + (codes, meta) + } + + /// (nq, nd, dim): every B-remainder class (`nd ∈ {1,5,6,7,8}`), A-panel + /// remainder (`17`), multi-tile B (`1250`), and `dim ∈ {64,128,256}`. + const CASES: &[(usize, usize, usize)] = &[ + (1, 1, 64), + (1, 5, 64), + (5, 1, 128), + (16, 4, 64), + (16, 5, 128), + (16, 6, 64), + (16, 7, 256), + (16, 8, 128), + (17, 9, 64), + (32, 16, 256), + (8, 1250, 128), + (64, 1250, 64), + // Odd dims exercise the even-K contract (padded_dim = dim + 1, the trailing + // column zero-padded on both sides), end-to-end against the reference. + (5, 3, 63), + (17, 9, 65), + (8, 33, 127), + (16, 7, 1), + ]; + + /// The public quantized staged path must match the scalar MinMax `MaxSim` + /// reference within tolerance (Stage A's integer dot is exact; only Stage B's + /// f32 accumulation order differs from the reference's). Exercises the full + /// public API: [`QuantStagedQuery`]/[`QuantStagedDocs`] build + compute. + #[test] + fn staged_i8_matches_minmax_reference() { + if V3::new_checked().is_none() { + return; // No AVX2 on this host. + } + + for &(nq, nd, dim) in CASES { + let q_data: Vec = (0..nq * dim).map(|i| rnd(1, i)).collect(); + let d_data: Vec = (0..nd * dim).map(|i| rnd(2, i)).collect(); + + // ── Path A: the public quantized staged kernel. ── + let q_f32 = MatRef::new(Standard::::new(nq, dim).unwrap(), &q_data).unwrap(); + let d_f32 = MatRef::new(Standard::::new(nd, dim).unwrap(), &d_data).unwrap(); + let mut query = QuantStagedQuery::build(q_f32).unwrap(); + let docs = QuantStagedDocs::build(d_f32); + let mut got = vec![0.0f32; nq]; + query.compute_max_sim(&docs, &mut got); + + // ── Path B: scalar MinMax MaxSim reference (identical quantization). ── + let q = quantizer(dim); + let q_mat = quantize(&q, &q_data, nq, dim); + let d_mat = quantize(&q, &d_data, nd, dim); + let query_ref: QueryMatRef<_> = q_mat.as_view().into(); + let mut ref_scores = vec![0.0f32; nq]; + MaxSim::new(&mut ref_scores).evaluate(query_ref, d_mat.as_view()); + + for i in 0..nq { + assert!( + (got[i] - ref_scores[i]).abs() <= 1e-4 * ref_scores[i].abs().max(1.0), + "({nq},{nd},{dim}) row {i}: staged-i8 min-dist {} != reference {}", + got[i], + ref_scores[i], + ); + } + } + } + + /// Reusing a single [`QuantStagedQuery`] across multiple `compute_max_sim` + /// calls (different doc sets / counts) must give correct results every time — + /// the regression guard for the [`ResettableArena`](super::super::arena::ResettableArena) + /// reset path: each call rewinds and re-fills the shared scratch, so a stale + /// or aliased buffer would corrupt the second/third call. + #[test] + fn staged_i8_arena_reuse_across_calls() { + if V3::new_checked().is_none() { + return; // No AVX2 on this host. + } + + const NQ: usize = 17; // exercises the A-panel row padding (17 -> 32) + const DIM: usize = 128; + let q_data: Vec = (0..NQ * DIM).map(|i| rnd(5, i)).collect(); + let q_f32 = MatRef::new(Standard::::new(NQ, DIM).unwrap(), &q_data).unwrap(); + let mut query = QuantStagedQuery::build(q_f32).unwrap(); + + let quant = quantizer(DIM); + let q_mat = quantize(&quant, &q_data, NQ, DIM); + + // Distinct doc counts (multi-tile, single panel, remainder) reusing the + // same query — the arena is reset, never reallocated, between calls. + for (call, &nd) in [251usize, 3, 64, 1].iter().enumerate() { + let d_data: Vec = (0..nd * DIM).map(|i| rnd(6 + call as u64, i)).collect(); + let d_f32 = MatRef::new(Standard::::new(nd, DIM).unwrap(), &d_data).unwrap(); + let docs = QuantStagedDocs::build(d_f32); + + let mut got = vec![0.0f32; NQ]; + query.compute_max_sim(&docs, &mut got); + + let d_mat = quantize(&quant, &d_data, nd, DIM); + let query_ref: QueryMatRef<_> = q_mat.as_view().into(); + let mut ref_scores = vec![0.0f32; NQ]; + MaxSim::new(&mut ref_scores).evaluate(query_ref, d_mat.as_view()); + + for i in 0..NQ { + assert!( + (got[i] - ref_scores[i]).abs() <= 1e-4 * ref_scores[i].abs().max(1.0), + "call {call} (nd={nd}) row {i}: reused staged-i8 {} != reference {}", + got[i], + ref_scores[i], + ); + } + } + } + + /// Isolate Stage A: the raw `i32` partial it stores must equal the brute-force + /// integer code dot `⟨codes_q, codes_d⟩` exactly (no float, no metadata). + #[test] + fn stage_a_integer_dot_exact() { + let Some(arch) = V3::new_checked() else { + return; + }; + + for &dim in &[64usize, 128, 130, 256] { + let padded_dim = dim.next_multiple_of(2); + let q = quantizer(dim); + // Exactly one A-panel (16 rows) × one B-panel (4 cols). + let q_data: Vec = (0..16 * dim).map(|i| rnd(3, i)).collect(); + let d_data: Vec = (0..4 * dim).map(|i| rnd(4, i)).collect(); + let q_mat = quantize(&q, &q_data, 16, dim); + let d_mat = quantize(&q, &d_data, 4, dim); + + let (d_codes, _) = doc_soa(&d_mat, dim); + let q_i16 = { + let bt = BlockTransposed::::new(16, dim); + let (c, _) = query_arrays(&q_mat, dim, bt.padded_nrows()); + c + }; + let q_mat_view = MatRef::new(Standard::::new(16, dim).unwrap(), &q_i16).unwrap(); + let bt = BlockTransposed::::from_matrix_view(q_mat_view.as_matrix_view()); + + let mut partial = vec![0i32; 16 * 4]; + // SAFETY: `bt` has exactly one block (16 rows) at `as_ptr()`; `d_codes` + // is 4 rows × padded_dim u8; `partial` is 4 cols × 16 i32 at stride 16. + unsafe { + int_store_microkernel::<4>( + arch, + bt.as_ptr(), + d_codes.as_ptr(), + padded_dim, + partial.as_mut_ptr(), + 16, + ); + } + + // Brute-force ⟨codes⟩ over the logical dim (codes are u8 ∈ [0,15]). + for i in 0..16 { + let qr = q_mat.get_row(i).unwrap(); + for jcol in 0..4 { + let dr = d_mat.get_row(jcol).unwrap(); + let expect: i32 = (0..dim) + .map(|d| { + i32::from(qr.vector().get(d).unwrap() as u8) + * i32::from(dr.vector().get(d).unwrap() as u8) + }) + .sum(); + assert_eq!( + partial[jcol * 16 + i], + expect, + "dim={dim} A-major partial[col {jcol}, row {i}] != brute force" + ); + } + } + } + } + + /// Drive the internal entry with a deliberately tiny cache budget so the + /// planner clamps to one A-panel and one B-panel per tile. With `nq > 16` and + /// `nd > 4` this forces **multiple A-tiles and multiple B-tiles**, exercising + /// the cross-tile `a_row_offset`/`b_row_offset` carry that the default-budget + /// reference test (one giant A-tile) never reaches — yet still matching the + /// scalar reference. + #[test] + fn staged_i8_multi_tile_tiny_budget() { + let Some(arch) = V3::new_checked() else { + return; + }; + + // l2_a / l1_b of 1 clamp `a_panels_per_tile` / `b_panels_per_tile` to 1 + // (both `.max(1)` in `StagedPlan::new`): a_tile_rows = 16, b_tile_rows = 4. + let budget = TileBudget { l2_a: 1, l1_b: 1 }; + + for &(nq, nd, dim) in &[(48usize, 22usize, 64usize), (33, 37, 128), (35, 19, 65)] { + let q = quantizer(dim); + let q_data: Vec = (0..nq * dim).map(|i| rnd(5, i)).collect(); + let d_data: Vec = (0..nd * dim).map(|i| rnd(6, i)).collect(); + let q_mat = quantize(&q, &q_data, nq, dim); + let d_mat = quantize(&q, &d_data, nd, dim); + + let padded_dim = dim.next_multiple_of(2); + let (d_codes, d_meta) = doc_soa(&d_mat, dim); + let doc = MatRef::new(Standard::::new(nd, padded_dim).unwrap(), &d_codes).unwrap(); + + let q_i16 = { + let probe = BlockTransposed::::new(nq, dim); + let (c, _) = query_arrays(&q_mat, dim, probe.padded_nrows()); + c + }; + let q_view = MatRef::new(Standard::::new(nq, dim).unwrap(), &q_i16).unwrap(); + let query_bt = BlockTransposed::::from_matrix_view(q_view.as_matrix_view()); + let (_, q_meta) = query_arrays(&q_mat, dim, query_bt.padded_nrows()); + let post = MinMaxPostprocess::new(&q_meta, &d_meta, dim); + + let mut state = vec![f32::MIN; query_bt.padded_nrows()]; + max_ip_kernel_staged_i8( + arch, + query_bt.as_view(), + doc, + &post, + &mut state, + ScopedAllocator::global(), + budget, + ); + + let query_ref: QueryMatRef<_> = q_mat.as_view().into(); + let mut ref_scores = vec![0.0f32; nq]; + MaxSim::new(&mut ref_scores).evaluate(query_ref, d_mat.as_view()); + + for i in 0..nq { + let got = -state[i]; + assert!( + (got - ref_scores[i]).abs() <= 1e-4 * ref_scores[i].abs().max(1.0), + "({nq},{nd},{dim}) row {i}: tiny-budget staged-i8 {got} != reference {}", + ref_scores[i], + ); + } + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/maxsim.rs b/diskann-quantization/src/multi_vector/distance/kernels/staged/maxsim.rs new file mode 100644 index 000000000..18cdfa88c --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/maxsim.rs @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Stage B / Stage C impls for MaxSim, plus the owned reset-arena scratch for the +//! `MaxSimKernel` (f32) path. +//! +//! The Stage-A kernel and the V3 entry point live in [`super::v3`]; the +//! Stage-C reducer's SIMD fold is V3-specific and also lives there. + +use core::marker::PhantomData; + +use diskann_wide::Architecture; + +use super::super::TileBudget; +use super::arena::ResettableArena; +use super::{FoldCtx, Postprocess}; +use crate::alloc::ScopedAllocator; + +// ── Stage B: identity ──────────────────────────────────────────── + +/// Identity postprocess: the raw inner product (`Acc`) *is* the `Score`. +/// +/// Reports [`scratch_len`](Postprocess::scratch_len) 0 and +/// [`apply`](Postprocess::apply) returns its input pointer unchanged, so the +/// driver folds `partial_buf` directly — no `scored_buf`, no extra pass, no +/// boolean flag. (Mirrors [`ConvertTo`](super::super::layouts::ConvertTo)'s +/// zero-cost identity blanket impl.) +pub(super) struct Identity(PhantomData); + +impl Identity { + pub(super) fn new() -> Self { + Self(PhantomData) + } +} + +// SAFETY: identity reads nothing beyond `acc`, writes nothing (`scratch_len` +// is 0), and returns exactly `acc` — `Score == Acc == T`, already A-major at the +// fixed output stride `a_panel`. The returned pointer carries the caller's +// validity for `acc` unchanged. +unsafe impl Postprocess for Identity { + type Acc = T; + type Score = T; + + #[inline(always)] + fn scratch_len(&self, _a_panel: usize, _max_b_cols: usize) -> usize { + 0 + } + + #[inline(always)] + unsafe fn apply(&self, _scratch: *mut T, _arch: A, acc: *const T, _ctx: FoldCtx) -> *const T { + acc + } +} + +/// MaxSim reducer: per-A-row running maximum of the inner products. The +/// `Reducer` impl (a register-resident `max_simd` sweep) is V3-specific and +/// lives in [`super::v3`]. +pub(super) struct MaxReducer; + +// ── Dispatch carrier ───────────────────────────────────────────── + +/// `state` (the output running-reduction) + `alloc` (the caller's allocator the +/// driver carves its internal `partial`/`scored` scratch from) bundled to cross +/// the [`Target3`](diskann_wide::arch::Target3) dispatch boundary. Only the +/// allocator crosses here — the driver allocates the scratch buffers itself, so +/// the caller hands in nothing but `state` and `alloc`. +pub(crate) struct StagedRun<'a> { + pub(crate) state: &'a mut [f32], + pub(crate) alloc: ScopedAllocator<'a>, +} + +// ── Owned reset-arena scratch for the f32 `MaxSimKernel` path ───── + +/// Per-kernel reusable scratch for the staged f32 [`MaxSimKernel`](super::super::MaxSimKernel) +/// path: the running-reduction `state` output plus a [`ResettableArena`] backing +/// the driver's transient `partial`/`scored`. Owned by `PreparedStaged` (behind a +/// `RefCell`, since `compute_max_sim` is `&self`) so steady-state calls allocate +/// nothing — the arena is reset, not reallocated, each call. +/// +/// This is the f32 counterpart to the i8 path's `QuantStagedQuery`-owned arena; +/// it lives here (not in `factory`) because sizing reads [`TileBudget`], which is +/// private to the `kernels` module tree. +#[derive(Debug)] +pub(crate) struct F32StagedScratch { + state: Vec, + arena: ResettableArena, +} + +impl F32StagedScratch { + /// Build scratch for a query of `padded` rows. The arena is sized once to the + /// provable `2·l1_b` ceiling (`StagedPlan` co-budgets `partial`/`scored` each + /// `≤ l1_b`), so it never needs per-shape sizing or reallocation. + #[allow(clippy::expect_used)] // POC: 72 KB arena, OOM is not a recoverable case here. + pub(crate) fn new(padded: usize) -> Self { + let arena_bytes = 2 * TileBudget::default().l1_b + 4096; + Self { + state: vec![f32::MIN; padded], + arena: ResettableArena::with_capacity(arena_bytes).expect("f32 staged arena"), + } + } + + /// Reset the arena, ensure `state` covers `padded` rows, then run `f` with the + /// `state` slice (the driver re-initialises it) and a `ScopedAllocator` over + /// the arena. + pub(crate) fn run( + &mut self, + padded: usize, + f: impl FnOnce(&mut [f32], ScopedAllocator<'_>) -> R, + ) -> R { + if self.state.len() < padded { + self.state.resize(padded, f32::MIN); + } + // Split the borrow so `state` (mut) and `arena` (shared, via the allocator) + // are simultaneously live as disjoint fields. + let Self { state, arena } = self; + arena.reset(); + f(&mut state[..padded], ScopedAllocator::new(arena)) + } +} + +#[cfg(test)] +mod tests { + use diskann_wide::arch::Scalar; + + use super::*; + + /// The identity postprocess returns its input pointer unchanged (the + /// zero-cost identity), so the driver folds `partial_buf` directly. + #[test] + fn identity_returns_source_pointer() { + let acc = [1.0f32, 2.0, 3.0, 4.0]; + let id = Identity::::new(); + let ctx = FoldCtx { + a_panel: 2, + valid_b_cols: 2, + b_stride: 2, + a_row_offset: 0, + b_row_offset: 0, + }; + // SAFETY: identity ignores `scratch` (its `scratch_len` is 0) and returns + // its input pointer unchanged. UFCS pins `A = Scalar` (the impl is blanket + // over `A`); the real driver pins `A = V3` via turbofish, so production + // never needs this. + let out = unsafe { + as Postprocess>::apply( + &id, + core::ptr::null_mut(), + Scalar::new(), + acc.as_ptr(), + ctx, + ) + }; + assert_eq!(out, acc.as_ptr()); + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/mod.rs b/diskann-quantization/src/multi_vector/distance/kernels/staged/mod.rs new file mode 100644 index 000000000..e151ec1cc --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/mod.rs @@ -0,0 +1,428 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Experimental *staged* multi-vector distance kernel. +//! +//! The production kernel (`super::tiled_reduce` + `super::f32`) fuses three +//! concerns into one micro-kernel epilogue: inner-product accumulation, +//! cross-row reduction, and merge into the per-A-row score scratch. This module +//! is a parallel, *separately selectable* kernel that splits those concerns into +//! three independently-pluggable stages so future work (quantized distances, +//! other reductions) can reuse the tiling loop without forking the micro-kernel: +//! +//! * **Stage A — [`StagedKernel`]**: pure SIMD math. Writes a raw `Acc` block +//! (`f32` here) into a per-A-panel `partial_buf`; no reduction, no merge. +//! * **Stage B — [`Postprocess`]**: `Acc` → `Score`, modeled on +//! [`ConvertTo`](super::layouts::ConvertTo). `apply` returns a read pointer and +//! the identity impl returns its input unchanged (reporting `scratch_len` 0), so +//! the driver runs it uniformly and identity stays zero-cost — no extra memory +//! pass, no boolean flag. +//! * **Stage C — [`Reducer`]**: owns the per-A-row `State` and folds `Score` +//! blocks into it. +//! +//! The traits keep raw-pointer methods (no generic methods) so a future +//! `&dyn Postprocess` / `&dyn Reducer` switched at the per-(A-panel, B-tile) +//! boundary stays possible without a monomorphization blow-up; we keep them +//! static-generic for now. +//! +//! Scope: single K-segment (no fractured-K), f32 + V3 (AVX2/FMA) only — an +//! apples-to-apples A/B against the fused V3 kernel. + +use diskann_wide::Architecture; + +use super::TileBudget; +use super::layouts::Layout; + +pub(super) mod arena; +pub(super) mod driver; +pub(super) mod i8; +pub(super) mod maxsim; +pub(super) mod v3; + +pub(crate) use maxsim::{F32StagedScratch, StagedRun}; +pub(crate) use v3::StagedF32Kernel; +// Public POC entry for the quantized (4-bit MinMax) staged kernel. +pub use i8::{QuantStagedDocs, QuantStagedQuery}; + +// ── Stage A: kernel ────────────────────────────────────────────── + +/// Stage A micro-kernel. Computes an `A_PANEL × B_PANEL` block of raw `Acc` +/// accumulators and **writes them out** to a partial buffer — unlike +/// [`super::Kernel`], it performs no cross-row reduction or scratch merge. +/// +/// # Safety +/// +/// Implementors must respect the per-method pointer contracts. +pub(super) unsafe trait StagedKernel { + /// Layout consumed by the A (left / query) side. + type Left: Layout; + /// Layout consumed by the B (right / document) side. + type Right: Layout; + /// Raw accumulator element written into `partial_buf` (`f32` for MaxSim). + type Acc: Copy; + + /// A rows processed per invocation (= the query `BlockTransposed` GROUP). + const A_PANEL: usize; + /// B rows processed per full invocation. + const B_PANEL: usize; + + /// Number of `Acc` elements the per-(A-panel, B-tile) `partial_buf` needs at + /// contraction dim `k` under `budget`. The single source of truth for the + /// partial size formula: the driver sizes its internal allocation from this, + /// derived from the panel geometry + the element sizes the kernel already + /// knows. + fn partial_len(k: usize, budget: TileBudget) -> usize { + let a_elem = core::mem::size_of::<::Element>(); + let b_elem = core::mem::size_of::<::Element>(); + let acc = core::mem::size_of::(); + StagedPlan::new( + k * a_elem, + k * b_elem, + Self::A_PANEL, + Self::B_PANEL, + acc, + budget, + ) + .partial_len(Self::A_PANEL, Self::B_PANEL) + } + + /// Write a full `A_PANEL × B_PANEL` block into `partial`. + /// + /// `partial` points at the first B-column of this panel; column `j` + /// (`0..B_PANEL`) and its `A_PANEL` rows occupy `partial[j*partial_b_stride + /// ..][..A_PANEL]` (A-major: `partial_b_stride == A_PANEL`). + /// + /// # Safety + /// + /// * `a` valid for `A_PANEL * k` `Left::Element`. + /// * `b` valid for `B_PANEL * k` `Right::Element`. + /// * `partial` valid for `B_PANEL` columns of `A_PANEL` `Acc` at stride + /// `partial_b_stride`. + unsafe fn full_panel( + arch: A, + a: *const ::Element, + b: *const ::Element, + k: usize, + partial: *mut Self::Acc, + partial_b_stride: usize, + ); + + /// Like [`Self::full_panel`] but writes only `remainder` (`1..B_PANEL`) + /// B-columns. + /// + /// # Safety + /// + /// As [`Self::full_panel`], with `b` valid for `remainder * k` + /// `Right::Element` and only `remainder` columns written. + unsafe fn partial_panel( + arch: A, + remainder: usize, + a: *const ::Element, + b: *const ::Element, + k: usize, + partial: *mut Self::Acc, + partial_b_stride: usize, + ); +} + +// ── Stage B: postprocess ───────────────────────────────────────── + +/// Per-(A-panel, B-tile) context passed to [`Postprocess::apply`]. +/// +/// All fields are cheap `Copy` scalars the driver already tracks. The identity +/// postprocess ignores them (they vanish after inlining); a metadata-bearing +/// postprocess (e.g. the quantized path) uses the global +/// `a_row_offset`/`b_row_offset` to index its per-vector metadata. +#[derive(Debug, Clone, Copy)] +pub(super) struct FoldCtx { + /// Rows in this A-panel (`== A_PANEL`). + pub(super) a_panel: usize, + /// Valid B columns in this block (`≤` the tile width). + pub(super) valid_b_cols: usize, + /// `Acc` column stride within the partial block (`== a_panel`). + pub(super) b_stride: usize, + /// Global index of this A-panel's first row. + pub(super) a_row_offset: usize, + /// Global index of this B-tile's first row. + pub(super) b_row_offset: usize, +} + +/// Stage B: convert one A-major block of raw `Acc` accumulators into finished +/// `Score`s, returning a read pointer to the scores. +/// +/// The driver sizes the staging region from [`scratch_len`](Self::scratch_len) +/// and allocates it (from the caller's allocator), then hands it to +/// [`apply`](Self::apply) as a raw `*mut Score`. The identity impl +/// ([`Identity`](maxsim::Identity)) reports `scratch_len == 0` and returns its +/// input pointer unchanged, so the driver runs one uniform path — no boolean, no +/// branch — and the identity case is zero-cost. A non-identity impl — e.g. the +/// quantized [`MinMaxPostprocess`](i8::MinMaxPostprocess), which turns raw `i32` +/// integer dot products into f32 MinMax scores using its own captured per-vector +/// metadata — writes into `scratch` and returns a pointer into it. +/// +/// # Safety +/// +/// Implementors must ensure that [`apply`](Self::apply): +/// - reads at most `ctx.valid_b_cols` columns of `ctx.a_panel` `Acc` from `acc` +/// at column stride `ctx.b_stride` (never the stale padded remainder columns); +/// - writes only within the `scratch` region it was given; +/// - returns a `*const Score` valid for `ctx.valid_b_cols` columns of +/// `ctx.a_panel` `Score` at the **fixed output column stride `ctx.a_panel`**. +/// (Identity returns `acc`, already at stride `a_panel`.) +pub(super) unsafe trait Postprocess { + /// Raw accumulator type produced by Stage A. + type Acc: Copy; + /// Finished score type consumed by Stage C. + type Score: Copy; + + /// Number of `Score` elements [`apply`](Self::apply) needs for one A-panel + /// against a B-tile of up to `max_b_cols` columns. `0` for the identity + /// postprocess (no staging — `apply` returns `acc`); the driver allocates + /// exactly this many `Score`s from the caller's allocator and passes the + /// region to `apply`. + fn scratch_len(&self, a_panel: usize, max_b_cols: usize) -> usize; + + /// Convert the `ctx.a_panel × ctx.valid_b_cols` A-major `Acc` block at `acc` + /// (column stride `ctx.b_stride`) into `Score`s, returning a read pointer. + /// Output is A-major at the fixed column stride `ctx.a_panel`; the identity + /// impl returns `acc` unchanged. Metadata-bearing impls index their own + /// per-vector metadata by `ctx.a_row_offset` / `ctx.b_row_offset`. + /// + /// # Safety + /// + /// * `acc` is valid for `ctx.valid_b_cols` columns of `ctx.a_panel` `Acc` at + /// stride `ctx.b_stride`. + /// * `scratch` is valid+writable for `scratch_len(ctx.a_panel, max_b_cols)` + /// `Score` with `max_b_cols ≥ ctx.valid_b_cols` (dangling when that is `0`). + unsafe fn apply( + &self, + scratch: *mut Self::Score, + arch: A, + acc: *const Self::Acc, + ctx: FoldCtx, + ) -> *const Self::Score; +} + +// ── Stage C: reducer ───────────────────────────────────────────── + +/// Stage C owns the per-A-row reduction `State` and folds `Score` blocks into +/// it. `Max` here; richer `State` shapes (argmax `(f32,u32)`, top-k) and an +/// `Output`/`finalize` step are follow-on work. +pub(super) trait Reducer { + /// Score element folded in (matches [`Postprocess::Score`]). + type Score: Copy; + /// Per-A-row running state (the score scratch element). + type State: Copy; + + /// Identity state for an A-row before any B-rows are seen. + fn init() -> Self::State; + + /// Fold an `A_PANEL × valid_b_cols` block of `Score` (read from + /// `partial_buf`) into `state[0..a_panel]`, in place. + /// + /// Column `c` (`0..valid_b_cols`), row `i` (`0..a_panel`) is at + /// `scores[c*b_stride + i]`. Only `valid_b_cols` columns are read — the + /// padded remainder columns hold stale data and **must not** be folded. + /// + /// # Safety + /// + /// * `state` valid+writable for `a_panel` `State`. + /// * `scores` valid for `valid_b_cols` columns of `a_panel` `Score` at + /// stride `b_stride`. + unsafe fn fold_block( + arch: A, + state: *mut Self::State, + scores: *const Self::Score, + a_panel: usize, + valid_b_cols: usize, + b_stride: usize, + ); +} + +// ── Stage conversion: StagedConvert ────────────────────────────── + +/// Staged-local tile conversion from layout `Self` to layout `To` — the staged +/// path's self-contained replacement for the shared +/// [`ConvertTo`](super::layouts::ConvertTo). Same role (convert a tile of source +/// data into the kernel's element type), but inverted ownership: instead of +/// owning a `Buffer`, the impl reports a [`scratch_len`](Self::scratch_len) and +/// the **driver** allocates that staging region from the caller's allocator and +/// hands it to [`convert`](Self::convert). The blanket identity impl reports `0` +/// and returns `src` unchanged, so identity conversions cost nothing and the +/// staged driver never touches the shared `ConvertTo` machinery. +/// +/// # Safety +/// +/// Implementors must ensure [`convert`](Self::convert) reads at most `rows * k` +/// source elements, writes only within the `scratch` region it was given, and +/// returns a pointer valid for `rows * k` `To::Element`. +pub(super) unsafe trait StagedConvert: Layout { + /// Number of `To::Element` the driver must allocate to convert up to + /// `max_tile_rows × k`. `0` for identity (no conversion — `convert` returns + /// `src`, ignoring `scratch`). + fn scratch_len(&self, max_tile_rows: usize, k: usize) -> usize; + + /// Convert `rows × k` `Self::Element` at `src` into `To::Element`, writing + /// into `scratch` (the driver-allocated region of `scratch_len(..)`), and + /// returning a read pointer. The identity impl returns `src` unchanged. + /// + /// # Safety + /// + /// * `src` points to `rows * k` valid `Self::Element`. + /// * `scratch` is valid+writable for `scratch_len(max_tile_rows, k)` + /// `To::Element` with `max_tile_rows ≥ rows` (dangling when that is `0`). + unsafe fn convert( + &self, + scratch: *mut To::Element, + arch: A, + src: *const Self::Element, + rows: usize, + k: usize, + ) -> *const To::Element; +} + +/// Identity conversion: every layout converts to itself at zero cost (no +/// scratch, returns `src`). Mirrors the shared `ConvertTo` blanket identity. +// SAFETY: identity reads nothing beyond `src`, writes nothing (`scratch_len` is +// 0), and returns exactly `src`, valid for the caller's lifetime. +unsafe impl StagedConvert for L { + fn scratch_len(&self, _max_tile_rows: usize, _k: usize) -> usize { + 0 + } + + unsafe fn convert( + &self, + _scratch: *mut L::Element, + _arch: A, + src: *const L::Element, + _rows: usize, + _k: usize, + ) -> *const L::Element { + src + } +} + +// ── Planner ────────────────────────────────────────────────────── + +/// Tile-panel counts for the staged loop. +/// +/// Encodes the **partial-buffer granularity** decision: the partial buffer is +/// always **one A-panel** (`P_a = 1`) wide in the A direction, against **as many +/// B-panels as co-fit L1** (`P_b = b_panels_per_tile`) in the B direction. +/// +/// `P_a = 1` because extra A-panels cut no B reads (B is re-streamed per A-panel +/// regardless) — they only enlarge `partial_buf`. `P_b = co-fit` rather than +/// `1×1` because total partial traffic is invariant to the fold granularity, so +/// the widest fold minimizes fold-call / `state`-reload overhead, maximizes the +/// SIMD sweep, and keeps Stage A (kernel) and Stage C (reduce) as contiguous +/// non-interleaved phases. See `docs/staged_multi_vector_kernel.md` §5. +#[derive(Debug, Clone, Copy)] +pub(super) struct StagedPlan { + pub(super) a_panels_per_tile: usize, + pub(super) b_panels_per_tile: usize, +} + +impl StagedPlan { + /// Choose `a_panels_per_tile` / `b_panels_per_tile` from the cache budgets, + /// the panel sizes, and the partial-buffer footprint. + /// + /// **L2** holds the A-tile (it is reused across every B-tile): + /// + /// ```text + /// a_panels_per_tile · A_PANEL · a_row_bytes ≤ l2_a + /// ``` + /// + /// **L1** holds *three* things at once during Stage A / Stage C, so the + /// planner co-budgets all three against `l1_b` (a usable fraction of L1) + /// rather than letting each independently claim the whole budget: + /// + /// * one A micro-panel — `A_PANEL · a_row_bytes` (re-read per B-panel); + /// * the B-tile data — `B_TILE_ROWS · b_row_bytes` (re-read per A-panel); + /// * `partial_buf` — `A_PANEL · B_TILE_ROWS · acc_bytes`. + /// + /// Each B-row added to the tile therefore costs `b_row_bytes` of document + /// data **plus** `A_PANEL · acc_bytes` of partial scratch, so we keep the + /// largest `B_TILE_ROWS` satisfying + /// + /// ```text + /// A_PANEL·a_row_bytes + B_TILE_ROWS·(b_row_bytes + A_PANEL·acc_bytes) ≤ l1_b + /// ``` + /// + /// This bounds `partial_buf + B-tile` together. (For very large `k` the A + /// micro-panel alone can approach `l1_b`; then `b_panels_per_tile` clamps to + /// 1 and the A-panel is the limit — inherent, and identical to the fused + /// kernel's behaviour.) + pub(super) fn new( + a_row_bytes: usize, + b_row_bytes: usize, + a_panel: usize, + b_panel: usize, + acc_bytes: usize, + budget: TileBudget, + ) -> Self { + let a_row_bytes = a_row_bytes.max(1); + let b_row_bytes = b_row_bytes.max(1); + + // L2: the A-tile is reused across all B-tiles, so size it to L2. + let a_panels_per_tile = (budget.l2_a / (a_row_bytes * a_panel)).max(1); + + // L1: co-budget the A micro-panel + B-tile data + partial_buf. Each + // B-row costs its document data plus one A_PANEL-tall partial column. + let a_panel_bytes = a_panel * a_row_bytes; + let bytes_per_b_row = b_row_bytes + a_panel * acc_bytes; + let b_tile_budget = budget.l1_b.saturating_sub(a_panel_bytes); + let b_panels_per_tile = ((b_tile_budget / bytes_per_b_row) / b_panel).max(1); + + Self { + a_panels_per_tile, + b_panels_per_tile, + } + } + + /// `partial_buf` capacity (in `Acc` elements): one A-panel (`P_a = 1`) wide, + /// covering a full B-tile (`b_panels_per_tile · B_PANEL` rows). The driver + /// caps every B-tile at `b_tile_rows`, so this is the exact upper bound. + pub(super) fn partial_len(&self, a_panel: usize, b_panel: usize) -> usize { + a_panel * self.b_panels_per_tile * b_panel + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Mirror of `L1_CACHE` in `super::super::TileBudget::default`. + const L1_CACHE_BYTES: usize = 48_000; + + /// The co-budget must keep the inner-loop L1 working set — one A + /// micro-panel + the B-tile data + `partial_buf` — within real L1 for every + /// realistic `k`. The previous design budgeted `partial_buf` and the B-tile + /// independently against `l1_b` and overflowed at small/moderate `k` (e.g. + /// k=16 placed ~70 KB into a 48 KB L1); this pins the fix. + #[test] + fn l1_working_set_fits_for_all_k() { + const A_PANEL: usize = 16; + const B_PANEL: usize = 4; + const ACC: usize = 4; // f32 + + // k up to 512: beyond ~768 the A micro-panel alone exceeds L1, which is + // inherent (the fused kernel hits the same wall) and not a planner bug. + for &k in &[1usize, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512] { + let row = k * 4; // f32 element + let plan = StagedPlan::new(row, row, A_PANEL, B_PANEL, ACC, TileBudget::default()); + let b_tile_rows = plan.b_panels_per_tile * B_PANEL; + + let a_panel_bytes = A_PANEL * row; + let b_data_bytes = b_tile_rows * row; + let partial_bytes = A_PANEL * b_tile_rows * ACC; + let working_set = a_panel_bytes + b_data_bytes + partial_bytes; + + assert!( + working_set <= L1_CACHE_BYTES, + "k={k}: L1 working set {working_set} B (a_panel={a_panel_bytes}, \ + b_data={b_data_bytes}, partial={partial_bytes}) exceeds L1 {L1_CACHE_BYTES} B", + ); + assert!(plan.a_panels_per_tile >= 1 && plan.b_panels_per_tile >= 1); + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/kernels/staged/v3.rs b/diskann-quantization/src/multi_vector/distance/kernels/staged/v3.rs new file mode 100644 index 000000000..e09328aa9 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/kernels/staged/v3.rs @@ -0,0 +1,395 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! V3 (AVX2+FMA) staged path: the store-out micro-kernel (Stage A), the +//! register-resident `max_simd` reducer (Stage C), and the f32 entry point. +//! +//! The Stage-A k-loop is byte-identical to the fused `super::super::f32::v3` +//! 16×4 kernel, so per-(query, doc) inner products are bit-identical. Only the +//! epilogue differs: instead of reducing the `UNROLL` accumulators and merging +//! into the score scratch, it **stores** them into `partial_buf` (A-major), +//! deferring the reduction to Stage C. +//! +//! The whole module is V3-specific; the experiment targets V3 only. + +use diskann_wide::arch::Target3; +use diskann_wide::arch::x86_64::V3; +use diskann_wide::{SIMDMinMax, SIMDMulAdd, SIMDVector}; + +use super::super::TileBudget; +use super::super::layouts::{self, DescribeLayout, Layout}; +use super::driver::tiled_reduce_staged; +use super::maxsim::{Identity, MaxReducer, StagedRun}; +use super::{Reducer, StagedConvert, StagedKernel}; +use crate::alloc::ScopedAllocator; +use crate::multi_vector::{BlockTransposedRef, MatRef, Standard}; + +diskann_wide::alias!(f32s = ::f32x8); + +/// Zero-sized Stage-A kernel marker for the f32 staged path with block size +/// `GROUP`. +pub(crate) struct StagedF32Kernel; + +// SAFETY: `full_panel`/`partial_panel` read A_PANEL(16) * k A elements and +// UNROLL * k B elements, and write UNROLL columns of A_PANEL(16) f32 into +// `partial` at stride `partial_b_stride` — all within the bounds the +// `StagedKernel` contract guarantees. +unsafe impl StagedKernel for StagedF32Kernel<16> { + type Left = layouts::BlockTransposed; + type Right = layouts::RowMajor; + type Acc = f32; + const A_PANEL: usize = 16; + const B_PANEL: usize = 4; + + #[inline(always)] + unsafe fn full_panel( + arch: V3, + a: *const f32, + b: *const f32, + k: usize, + partial: *mut f32, + partial_b_stride: usize, + ) { + // SAFETY: pointer validity per the `StagedKernel` contract. + unsafe { store_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, partial, partial_b_stride) } + } + + #[inline(always)] + unsafe fn partial_panel( + arch: V3, + remainder: usize, + a: *const f32, + b: *const f32, + k: usize, + partial: *mut f32, + partial_b_stride: usize, + ) { + // SAFETY: pointer validity per the `StagedKernel` contract. + unsafe { + match remainder { + 1 => store_microkernel::<1>(arch, a, b, k, partial, partial_b_stride), + 2 => store_microkernel::<2>(arch, a, b, k, partial, partial_b_stride), + 3 => store_microkernel::<3>(arch, a, b, k, partial, partial_b_stride), + _ => unreachable!( + "unexpected remainder {remainder} for B_PANEL={}", + Self::B_PANEL + ), + } + } + } +} + +/// V3 store-out micro-kernel: 16 A-rows × `UNROLL` B-rows. +/// +/// The accumulation loop matches `super::super::f32::v3::f32_microkernel` +/// exactly (two `f32x8` register tiles, FMA, same splat/stride/unroll order). +/// The epilogue stores each B-column's 16 A-row accumulators into `partial` +/// (A-major: column `j` at `partial + j*b_stride`, as two `f32x8` halves). +/// +/// # Safety +/// +/// 1. `a_packed` points to `16 * k` contiguous `f32`. +/// 2. `b` points to `UNROLL` rows of `k` contiguous `f32`. +/// 3. `partial` is valid for `UNROLL` columns of 16 `f32` at stride `b_stride`. +#[inline(always)] +unsafe fn store_microkernel( + arch: V3, + a_packed: *const f32, + b: *const f32, + k: usize, + partial: *mut f32, + b_stride: usize, +) { + let mut p0 = [f32s::default(arch); UNROLL]; + let mut p1 = [f32s::default(arch); UNROLL]; + let offsets: [usize; UNROLL] = core::array::from_fn(|i| k * i); + + let a_stride = 2 * f32s::LANES; + let a_stride_half = f32s::LANES; + + for i in 0..k { + // SAFETY: preconditions 1 and 2; i < k and j < UNROLL. + unsafe { + let a0 = f32s::load_simd(arch, a_packed.add(a_stride * i)); + let a1 = f32s::load_simd(arch, a_packed.add(a_stride * i + a_stride_half)); + + for j in 0..UNROLL { + let bj = f32s::splat(arch, b.add(i + offsets[j]).read_unaligned()); + p0[j] = a0.mul_add_simd(bj, p0[j]); + p1[j] = a1.mul_add_simd(bj, p1[j]); + } + } + } + + for j in 0..UNROLL { + // SAFETY: precondition 3; column j occupies [j*b_stride, j*b_stride+16). + unsafe { + p0[j].store_simd(partial.add(j * b_stride)); + p1[j].store_simd(partial.add(j * b_stride + a_stride_half)); + } + } +} + +// ── Stage C: V3 SIMD max reducer ───────────────────────────────── + +impl Reducer for MaxReducer { + type Score = f32; + type State = f32; + + #[inline(always)] + fn init() -> f32 { + f32::MIN + } + + #[inline(always)] + unsafe fn fold_block( + arch: V3, + state: *mut f32, + scores: *const f32, + a_panel: usize, + valid_b_cols: usize, + b_stride: usize, + ) { + let lanes = f32s::LANES; + + // The V3 staged kernel always folds a full A_PANEL = 16 = 2*LANES block: + // two register-resident accumulators sweep the valid B-columns of + // `partial_buf` in a single pass — the same access pattern the fused + // kernel uses, just hoisted out of the inner B-loop. + debug_assert_eq!( + a_panel, + 2 * lanes, + "V3 MaxReducer expects A_PANEL == 2*LANES" + ); + + // SAFETY: `state` is writable for 16; `scores` is valid for `valid_b_cols` + // columns of 16 f32 at `b_stride`; only the valid columns are read (never + // the stale padded remainder). + unsafe { + let mut a0 = f32s::load_simd(arch, state); + let mut a1 = f32s::load_simd(arch, state.add(lanes)); + for c in 0..valid_b_cols { + let col = scores.add(c * b_stride); + a0 = a0.max_simd(f32s::load_simd(arch, col)); + a1 = a1.max_simd(f32s::load_simd(arch, col.add(lanes))); + } + a0.store_simd(state); + a1.store_simd(state.add(lanes)); + } + } +} + +// ── Entry point ────────────────────────────────────────────────── + +/// Compute per-A-row max inner product (block-transposed A query, row-major B +/// doc) into `state` via the staged pipeline. `state` (len ≥ `padded_nrows`) is +/// the caller's output, left holding the raw max-IP (the caller negates). +/// Transient scratch (`partial`, Stage-B region) is allocated internally from +/// `alloc` — the caller sizes nothing. +/// +/// # Panics +/// +/// Panics if `state.len() < a.padded_nrows()` or `a.ncols() != b.vector_dim()`. +pub(crate) fn max_ip_kernel_staged( + arch: V3, + a: BlockTransposedRef<'_, f32, GROUP>, + b: MatRef<'_, Standard>, + state: &mut [f32], + alloc: ScopedAllocator<'_>, + budget: TileBudget, +) where + StagedF32Kernel: StagedKernel, + layouts::BlockTransposed: StagedConvert as StagedKernel>::Left> + + Layout, + layouts::RowMajor: StagedConvert as StagedKernel>::Right> + + Layout, +{ + let padded = a.padded_nrows(); + if state.len() < padded || a.ncols() != b.vector_dim() { + max_ip_kernel_staged_panic(state.len(), padded, a.ncols(), b.vector_dim()); + } + + // A_PANEL must equal GROUP for block-transposed layout correctness. + const { assert!( as StagedKernel>::A_PANEL == GROUP) } + + let k = a.ncols(); + let b_nrows = b.num_vectors(); + + // Empty contraction: every IP is 0 ⇒ max-IP is 0. Callers guarantee + // b_nrows > 0 (the zero-doc case is short-circuited before reaching here). + if k == 0 { + state[..padded].fill(0.0); + return; + } + + let ca = a.layout(); + let cb = b.layout(); + let post = Identity::::new(); + + // SAFETY: + // - a.as_ptr() is valid for padded * k f32, and padded is a multiple of + // GROUP == A_PANEL (const-asserted above). + // - b.as_slice() is num_vectors * vector_dim contiguous f32. + // - state.len() >= padded (checked); the driver allocates its scratch from + // `alloc`. + unsafe { + tiled_reduce_staged::, Identity, MaxReducer, _, _>( + arch, + &ca, + &cb, + &post, + a.as_ptr(), + padded, + b.as_slice().as_ptr(), + b_nrows, + k, + &mut state[..padded], + alloc, + budget, + ); + } +} + +#[inline(never)] +#[cold] +#[allow(clippy::panic)] +fn max_ip_kernel_staged_panic(state_len: usize, padded: usize, a_ncols: usize, b_dim: usize) { + panic!( + "max_ip_kernel_staged: precondition failed: \ + state.len()={state_len} (expected >= {padded}), \ + a.ncols()={a_ncols}, b.vector_dim()={b_dim}" + ); +} + +// ── Dispatch glue ──────────────────────────────────────────────── + +impl + Target3, MatRef<'_, Standard>, StagedRun<'_>> + for StagedF32Kernel +where + StagedF32Kernel: StagedKernel, + layouts::BlockTransposed: + StagedConvert>::Left> + Layout, + layouts::RowMajor: + StagedConvert>::Right> + Layout, +{ + #[inline(always)] + fn run( + self, + arch: V3, + lhs: BlockTransposedRef<'_, f32, GROUP>, + rhs: MatRef<'_, Standard>, + scratch: StagedRun<'_>, + ) { + max_ip_kernel_staged( + arch, + lhs, + rhs, + scratch.state, + scratch.alloc, + TileBudget::default(), + ); + } +} + +#[cfg(test)] +mod tests { + use diskann_wide::arch::x86_64::V3; + + use super::super::super::TileBudget; + use super::super::super::f32::max_ip_kernel; + use super::max_ip_kernel_staged; + use crate::alloc::ScopedAllocator; + use crate::multi_vector::{BlockTransposed, MatRef, Standard}; + + // (a_nrows, b_nrows, dim): degenerate, zero-dim, zero-doc, prime k, A/B-panel + // boundaries and every B-remainder class for V3 (B_PANEL=4), plus multi-tile. + const CASES: &[(usize, usize, usize)] = &[ + (1, 1, 4), + (1, 5, 8), + (5, 1, 8), + (5, 3, 5), + (3, 2, 0), // zero dim + (3, 0, 4), // zero docs + (7, 7, 32), + (2, 3, 128), + (16, 4, 64), // one A-panel, no B remainder + (17, 4, 64), // A-panel remainder + (16, 5, 8), // B remainder = 1 + (16, 6, 32), // B remainder = 2 + (16, 7, 32), // B remainder = 3 + (16, 8, 32), + (32, 5, 16), + (48, 3, 16), + (8, 32, 128), + (64, 32, 128), + (32, 16, 256), + (64, 1250, 512), // multi-tile B + ]; + + fn naive(a: &[f32], a_nrows: usize, b: &[f32], b_nrows: usize, k: usize) -> Vec { + (0..a_nrows) + .map(|i| { + (0..b_nrows) + .map(|j| (0..k).map(|d| a[i * k + d] * b[j * k + d]).sum::()) + .fold(f32::MIN, f32::max) + }) + .collect() + } + + /// The staged kernel must be **bit-identical** to the fused V3 kernel (same + /// k-loop ⇒ same per-pair IP; `max` is order-independent), and within + /// tolerance of the naive reference. + #[test] + fn staged_matches_fused_v3() { + let Some(arch) = V3::new_checked() else { + // No AVX2/FMA on this host; the staged V3 path cannot run. + return; + }; + + for &(a_nrows, b_nrows, dim) in CASES { + let a_data: Vec = (0..a_nrows * dim).map(|i| (i % 13 + 1) as f32).collect(); + let b_data: Vec = (0..b_nrows * dim).map(|i| (i % 7 + 1) as f32).collect(); + + let a_mat = MatRef::new(Standard::new(a_nrows, dim).unwrap(), &a_data).unwrap(); + let a_bt = BlockTransposed::::from_matrix_view(a_mat.as_matrix_view()); + let b_mat = MatRef::new(Standard::new(b_nrows, dim).unwrap(), &b_data).unwrap(); + + let mut fused = vec![f32::MIN; a_bt.padded_nrows()]; + max_ip_kernel::( + arch, + a_bt.as_view(), + b_mat, + &mut fused, + TileBudget::default(), + ); + + let mut state = vec![f32::MIN; a_bt.padded_nrows()]; + max_ip_kernel_staged::<16>( + arch, + a_bt.as_view(), + b_mat, + &mut state, + ScopedAllocator::global(), + TileBudget::default(), + ); + + let expected = naive(&a_data, a_nrows, &b_data, b_nrows, dim); + for i in 0..a_nrows { + assert_eq!( + state[i].to_bits(), + fused[i].to_bits(), + "staged != fused at row {i} for ({a_nrows},{b_nrows},{dim}): staged={}, fused={}", + state[i], + fused[i], + ); + assert!( + (state[i] - expected[i]).abs() < 1e-6 * expected[i].abs().max(1.0), + "staged != naive at row {i} for ({a_nrows},{b_nrows},{dim}): staged={}, naive={}", + state[i], + expected[i], + ); + } + } + } +} diff --git a/diskann-quantization/src/multi_vector/distance/mod.rs b/diskann-quantization/src/multi_vector/distance/mod.rs index ef336161c..bdbc9e83c 100644 --- a/diskann-quantization/src/multi_vector/distance/mod.rs +++ b/diskann-quantization/src/multi_vector/distance/mod.rs @@ -53,3 +53,8 @@ pub use fallback::QueryMatRef; pub use isa::{MaxSimIsa, NotSupported}; pub use kernel::{BoxErase, Erase, MaxSimKernel}; pub use max_sim::{Chamfer, MaxSim, MaxSimError}; + +/// Standalone POC entry for the 4-bit MinMax *staged* multi-vector MaxSim kernel +/// (V3/AVX2 only) — not yet unified into [`MaxSimIsa`]/[`build_max_sim`]. +#[cfg(target_arch = "x86_64")] +pub use kernels::{QuantStagedDocs, QuantStagedQuery};