Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,77 @@ mod tests {
}
}

#[test]
fn test_new_checkpoint_creates_fresh_record() -> ANNResult<()> {
let temp_dir = tempdir()?;
let index_prefix = temp_dir
.path()
.join("fresh_index")
.to_str()
.unwrap()
.to_string();
// Two managers with the same prefix+identifier should see the same checkpoint state
let manager_a = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 42);
let manager_b = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 42);
assert_eq!(
manager_a.get_resumption_point(WorkStage::Start)?,
manager_b.get_resumption_point(WorkStage::Start)?
);
// A different identifier should be independent
let manager_c = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 99);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me how this asserts independence from a and b like the comment states.

assert!(!manager_c.has_completed()?);
Ok(())
}

#[test]
fn test_has_completed_false_when_no_file() -> ANNResult<()> {
let temp_dir = tempdir()?;
let index_prefix = temp_dir
.path()
.join("nonexistent_index")
.to_str()
.unwrap()
.to_string();
let manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 999);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to be the identical test to the last thing above?

assert!(!manager.has_completed()?);
Ok(())
}

#[test]
fn test_mark_as_invalid() -> ANNResult<()> {
let temp_dir = tempdir()?;
let index_prefix = temp_dir
.path()
.join("test_invalid")
.to_str()
.unwrap()
.to_string();
let identifier = 77;

let mut manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
// Advance to a later stage with some progress
manager.update(Progress::Completed, WorkStage::QuantizeFPV)?;
manager.update(Progress::Processed(42), WorkStage::InMemIndexBuild)?;

// Verify we can resume from progress=42
let manager2 = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
assert_eq!(
manager2.get_resumption_point(WorkStage::QuantizeFPV)?,
Some(42)
);

// Mark as invalid - progress resets to 0 (is_valid=false => progress read as 0)
let mut manager3 = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
manager3.mark_as_invalid()?;
assert_eq!(
manager3.get_resumption_point(WorkStage::QuantizeFPV)?,
Some(0)
);

clean_checkpoint_file(&index_prefix, identifier);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a big deal, but part of point of tempdir() is that the object returned cleans up after itself on drop. This kind of clean is only really needed if you were doing multiple things in the same tempdir and wanted to clean in between.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Ok(())
}

#[test]
fn test_checkpoint_manager_interruption_and_resumption() -> ANNResult<()> {
let temp_dir = tempdir()?;
Expand Down
144 changes: 144 additions & 0 deletions diskann-disk/src/build/chunking/continuation/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,150 @@ mod tests {
}
}

/// A tracker that returns Stop after `stop_after` Continue grants.
#[derive(Clone)]
struct StopAfterTracker {
count: std::sync::Arc<std::sync::Mutex<usize>>,
stop_after: usize,
}

impl ContinuationTrackerTrait for StopAfterTracker {
fn get_continuation_grant(&self) -> ContinuationGrant {
let mut count = self.count.lock().unwrap();
if *count >= self.stop_after {
ContinuationGrant::Stop
} else {
*count += 1;
ContinuationGrant::Continue
}
}
}

#[test]
fn test_process_while_resource_is_available_stops_early() {
let tracker = StopAfterTracker {
count: std::sync::Arc::new(std::sync::Mutex::new(0)),
stop_after: 3,
};
let items = vec![10, 20, 30, 40, 50];
let mut processed = Vec::new();

let result = process_while_resource_is_available(
|item| {
processed.push(item);
Ok::<(), TestError>(())
},
items.into_iter(),
Box::new(tracker),
);

assert!(result.is_ok());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is redundant with the next one. If you match on result.unwrap() it will already panic if it's not ok.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

match result.unwrap() {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably better written as an if let Progress::processed(idx) = result.unwrap() { } else { panic!(); }. I'm surprised clippy didn't warn on this; it is usually not happy with single arm match statements.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Even removed panic!() to increase code coverage number.

Progress::Processed(idx) => {
assert_eq!(idx, 3); // stopped before processing item at index 3

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct? Processed(3) means processed to 3, or processed until 3? I really hope it is the former.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. it is the former. fixed.

assert_eq!(processed, vec![10, 20, 30]);
}
_ => panic!("Expected Processed"),
}
}

/// A tracker that yields once (with a tiny duration), then continues.
#[derive(Clone)]
struct YieldOnceThenContinueTracker {
yielded: std::sync::Arc<std::sync::Mutex<bool>>,
}

impl ContinuationTrackerTrait for YieldOnceThenContinueTracker {
fn get_continuation_grant(&self) -> ContinuationGrant {
let mut yielded = self.yielded.lock().unwrap();
if !*yielded {
*yielded = true;
ContinuationGrant::Yield(std::time::Duration::ZERO)
} else {
Comment thread
arrayka marked this conversation as resolved.
// After yielding once, always continue
ContinuationGrant::Continue
}
}
}

#[test]
fn test_process_while_resource_is_available_yield_then_continue() {
let tracker = YieldOnceThenContinueTracker {
yielded: std::sync::Arc::new(std::sync::Mutex::new(false)),
};
let items = vec![1, 2];
let mut processed = Vec::new();

let result = process_while_resource_is_available(
|item| {
processed.push(item);
Ok::<(), TestError>(())
},
items.into_iter(),
Box::new(tracker),
);

assert!(result.is_ok());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comments from above apply here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

// After yielding, it should have continued and processed all items
match result.unwrap() {
Progress::Completed => assert_eq!(processed, vec![1, 2]),
_ => panic!("Expected Completed"),
}
}

#[test]
fn test_process_while_resource_is_available_action_error() {
let checker = Box::new(NaiveContinuationTracker::default());
let items = vec![1, 2, 3];

let result = process_while_resource_is_available(
|item| {
if item == 2 {
Err(TestError)
} else {
Ok(())
}
},
items.into_iter(),
checker,
);

assert!(result.is_err());
}

#[tokio::test]
async fn test_process_while_resource_is_available_async_stops_early() {
let tracker = StopAfterTracker {
count: std::sync::Arc::new(std::sync::Mutex::new(0)),
stop_after: 2,
};
let items = vec![1, 2, 3, 4, 5];
let processed = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));

let result = process_while_resource_is_available_async(
|item| {
let processed = processed.clone();
async move {
processed.lock().await.push(item);
Ok::<(), TestError>(())
}
},
items.into_iter(),
Box::new(tracker),
)
.await;

assert!(result.is_ok());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And same here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

match result.unwrap() {
Progress::Processed(idx) => {
assert_eq!(idx, 2);
let processed = processed.lock().await;
assert_eq!(*processed, vec![1, 2]);
}
_ => panic!("Expected Processed"),
}
}

#[tokio::test]
async fn test_process_while_resource_is_available_async_completes() {
let checker = Box::new(NaiveContinuationTracker::default());
Expand Down
36 changes: 29 additions & 7 deletions diskann-disk/src/search/pq/pq_scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,22 @@ impl PQScratch {
})
}

/// Copy the first `dim` elements of `query` into `query_scratch`.
/// Copy `query` into `query_scratch`.
///
/// `query` must already be in full-precision `f32` representation; quantized
/// inputs (e.g. `MinMaxElement`) should be decoded via `VectorRepr::as_f32`
/// at the caller boundary before invoking this method.
///
/// Accepts oversized `query` (only the first `dim` elements are used) for
/// backwards compatibility with callers that hold alignment-padded buffers.
/// Returns `DimensionMismatchError` if `query.len() < query_scratch.len()`.
/// Returns `DimensionMismatchError` if `query.len() != query_scratch.len()`.
pub fn set(&mut self, query: &[f32]) -> ANNResult<()> {
let dim = self.query_scratch.len();
if query.len() < dim {
if query.len() != dim {
return Err(ANNError::log_dimension_mismatch_error(format!(
"PQScratch::set: expected query of length >= {dim}, got {}",
"PQScratch::set: expected query of length {dim}, got {}",
query.len()
)));
}
self.query_scratch.copy_from_slice(&query[..dim]);
self.query_scratch.copy_from_slice(query);
Ok(())
}

Expand Down Expand Up @@ -128,4 +126,28 @@ mod tests {
assert_eq!(pq_scratch.query_scratch[i], query[i]);
});
}

#[test]
fn test_pq_scratch_set_rejects_short_query() {
let dim = 16;
let mut pq_scratch = PQScratch::new(64, dim, 4, 256).unwrap();

// Query shorter than dim should fail
let short_query: Vec<f32> = (1..dim).map(|i| i as f32).collect(); // dim-1 elements
let err = pq_scratch.set(&short_query).unwrap_err();
assert_eq!(err.kind(), diskann::ANNErrorKind::DimensionMismatchError);
assert!(err.to_string().contains("expected query of length"));
}

#[test]
fn test_pq_scratch_set_rejects_oversized_query() {
let dim = 8;
let mut pq_scratch = PQScratch::new(64, dim, 4, 256).unwrap();

// Query longer than dim should fail
let long_query: Vec<f32> = (1..=dim + 10).map(|i| i as f32).collect();
let err = pq_scratch.set(&long_query).unwrap_err();
assert_eq!(err.kind(), diskann::ANNErrorKind::DimensionMismatchError);
assert!(err.to_string().contains("expected query of length"));
}
Comment thread
arrayka marked this conversation as resolved.
}
44 changes: 44 additions & 0 deletions diskann-disk/src/search/provider/disk_sector_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,4 +373,48 @@ mod disk_sector_graph_test {
let data = &graph;
assert_eq!(data.len(), 512);
}

#[test]
fn test_reconfigure_grows_buffer() {
let reader = AlignedFileReaderFactory::new(test_index_path())
.build()
.unwrap();
let mut graph = test_initialize_disk_sector_graph(2, 1, reader);
assert_eq!(graph.max_n_batch_sector_read, 4);

// Reconfigure to larger batch — buffer must grow beyond initial 512 bytes
graph.reconfigure(16).unwrap();
assert_eq!(graph.max_n_batch_sector_read, 16);
assert_eq!(graph.sectors_data.len(), 16 * 64);
}

#[test]
fn test_reconfigure_noop_for_smaller_size() {
let reader = AlignedFileReaderFactory::new(test_index_path())
.build()
.unwrap();
let mut graph = test_initialize_disk_sector_graph(2, 1, reader);
let original_len = graph.sectors_data.len();

// Reconfigure with same or smaller size should be a no-op
graph.reconfigure(4).unwrap();
assert_eq!(graph.max_n_batch_sector_read, 4);
assert_eq!(graph.sectors_data.len(), original_len);

graph.reconfigure(2).unwrap();
assert_eq!(graph.max_n_batch_sector_read, 4);
assert_eq!(graph.sectors_data.len(), original_len);
}

#[test]
fn test_new_disk_sector_graph_zero_block_size_defaults() {
let metadata = GraphMetadata::new(1000, 32, 500, 32, 2, 20, 50, 1024, 256);
// block_size = 0 should fall back to DEFAULT_DISK_SECTOR_LEN regardless of version
let header = GraphHeader::new(metadata, 0, GraphLayoutVersion::new(1, 0));
let reader = AlignedFileReaderFactory::new(test_index_path())
.build()
.unwrap();
let graph = DiskSectorGraph::new(reader, &header, 2).unwrap();
assert_eq!(graph.block_size, DEFAULT_DISK_SECTOR_LEN);
}
}
36 changes: 36 additions & 0 deletions diskann-disk/src/storage/quant/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,4 +598,40 @@ mod generator_tests {

Ok(())
}

#[test]
fn test_validate_params_missing_compressed_file() -> ANNResult<()> {
let storage_provider = VirtualStorageProvider::new_memory();
storage_provider
.filesystem()
.create_dir("/test_data")
.expect("Could not create test directory");

let data_path = "/test_data/data.bin";
let compressed_path = "/test_data/compressed.bin";
let num_points = 100;
let dim = 8;
let output_dim = 4u32;

// Create source data
let data = create_test_data(num_points, dim);
let view = MatrixView::try_from(data.as_slice(), num_points, dim).unwrap();
write_bin(view, &mut storage_provider.create_for_write(data_path)?)?;

// Don't create compressed file but set offset > 0
let context = GeneratorContext::new(10, compressed_path.to_string());
let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
data_path.to_string(),
context,
&output_dim,
)
.unwrap();

let err = generator
.validate_params(num_points, &storage_provider)
.unwrap_err();
assert_eq!(err.kind(), diskann::ANNErrorKind::FileNotFoundError);
assert!(err.to_string().contains("expected compressed file"));
Ok(())
}
}
Loading
Loading