Skip to content

Commit

Permalink
fix: handle null vectors in flat search (#3422)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 authored Jan 28, 2025
1 parent 7c34f14 commit 7aa7d94
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 8 deletions.
8 changes: 4 additions & 4 deletions rust/lance-index/src/vector/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::{array::AsArray, buffer::NullBuffer};
use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field as ArrowField};
use lance_arrow::*;
Expand Down Expand Up @@ -44,9 +44,9 @@ pub async fn compute_distance(
.clone();

let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
rowids.nulls().map(|nulls| nulls.buffer().clone())
NullBuffer::union(rowids.nulls(), vectors.nulls())
} else {
None
vectors.nulls().cloned()
};

tokio::task::spawn_blocking(move || {
Expand All @@ -56,7 +56,7 @@ pub async fn compute_distance(
let vectors = vectors
.into_data()
.into_builder()
.null_bit_buffer(validity_buffer)
.null_bit_buffer(validity_buffer.map(|b| b.buffer().clone()))
.build()
.map(make_array)?;
let distances = match vectors.data_type() {
Expand Down
3 changes: 3 additions & 0 deletions rust/lance-index/src/vector/pq/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ pub(super) fn compute_pq_distance(
num_sub_vectors: usize,
code: &[u8],
) -> Vec<f32> {
if code.is_empty() {
return Vec::new();
}
if num_bits == 4 {
return compute_pq_distance_4bit(distance_table, num_sub_vectors, code);
}
Expand Down
78 changes: 74 additions & 4 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1742,14 +1742,15 @@ mod tests {

use arrow_array::types::UInt64Type;
use arrow_array::{
make_array, Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array,
make_array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader, UInt64Array,
};
use arrow_buffer::{BooleanBuffer, NullBuffer};
use arrow_schema::Field;
use arrow_schema::{DataType, Field, Schema};
use itertools::Itertools;
use lance_core::utils::address::RowAddress;
use lance_core::ROW_ID;
use lance_datagen::{array, gen, Dimension, RowCount};
use lance_datagen::{array, gen, ArrayGeneratorExt, Dimension, RowCount};
use lance_index::vector::sq::builder::SQBuildParams;
use lance_linalg::distance::l2_distance_batch;
use lance_testing::datagen::{
Expand All @@ -1760,7 +1761,7 @@ mod tests {
use rstest::rstest;
use tempfile::tempdir;

use crate::dataset::InsertBuilder;
use crate::dataset::{InsertBuilder, WriteMode, WriteParams};
use crate::index::prefilter::DatasetPreFilter;
use crate::index::vector::IndexFileVersion;
use crate::index::vector_index_details;
Expand Down Expand Up @@ -2300,6 +2301,75 @@ mod tests {
assert_eq!(results["vec"].logical_null_count(), 0);
}

#[tokio::test]
async fn test_index_lifecycle_nulls() {
    // Regression test: an IVF_PQ index must keep returning exactly the
    // non-null rows through the full lifecycle — create, append, optimize —
    // when roughly half of the indexed vectors are null.
    // NOTE(review): relies on `nearest(..., k)` with k >= total row count
    // returning every non-null row — confirm against the scanner's contract.

    // Query with a zero vector asking for more neighbors than rows exist,
    // then check that exactly the non-null rows come back.
    async fn check_index(dataset: &Dataset, expected_rows: usize, dims: usize) {
        let query: Float32Array = vec![0.0; dims].into_iter().collect();
        let results = dataset
            .scan()
            .nearest("vec", &query, 2_000)
            .unwrap()
            .nprobs(2)
            .try_into_batch()
            .await
            .unwrap();
        assert_eq!(results.num_rows(), expected_rows);
    }

    // Random vectors with ~50% nulls.
    let dims = 32;
    let batch = gen()
        .col(
            "vec",
            array::rand_vec::<Float32Type>(Dimension::from(dims as u32)).with_random_nulls(0.5),
        )
        .into_batch_rows(RowCount::from(2_000))
        .unwrap();
    let mut non_null = batch["vec"].len() - batch["vec"].logical_null_count();

    let mut dataset = InsertBuilder::new("memory://")
        .execute(vec![batch])
        .await
        .unwrap();

    // Build an IVF_PQ index over the nullable vector column.
    let index_params = VectorIndexParams::with_ivf_pq_params(
        MetricType::L2,
        IvfBuildParams::new(2),
        PQBuildParams::new(2, 8),
    );
    dataset
        .create_index(&["vec"], IndexType::Vector, None, &index_params, false)
        .await
        .unwrap();
    check_index(&dataset, non_null, dims).await;

    // Append a second batch of nullable vectors; the not-yet-indexed tail
    // must still be searchable (this exercises the flat search path).
    let batch = gen()
        .col(
            "vec",
            array::rand_vec::<Float32Type>(Dimension::from(dims as u32)).with_random_nulls(0.5),
        )
        .into_batch_rows(RowCount::from(500))
        .unwrap();
    non_null += batch["vec"].len() - batch["vec"].logical_null_count();
    let mut dataset = InsertBuilder::new(Arc::new(dataset))
        .with_params(&WriteParams {
            mode: WriteMode::Append,
            ..Default::default()
        })
        .execute(vec![batch])
        .await
        .unwrap();
    check_index(&dataset, non_null, dims).await;

    // Folding the appended rows into the index must not change the results.
    dataset.optimize_indices(&Default::default()).await.unwrap();
    check_index(&dataset, non_null, dims).await;
}

#[tokio::test]
async fn test_create_ivf_pq_cosine() {
let test_dir = tempdir().unwrap();
Expand Down

0 comments on commit 7aa7d94

Please sign in to comment.