Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,12 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;

// Validate that the data type is correct by checking the vector/neighbor section of the stored node data.
// disk_nnodes: total node count in the index, used to validate the neighbor indices of a node.
// contains_disk_pq_file: flag indicating whether the index contains a disk PQ file (suffix: _disk.index_pq_pivots.bin).
// If it is true, max_node_len uses uint8 to calculate the data size instead (see aux_utils.cpp:1399).
bool validate_vector_data_type(uint64_t disk_nnodes, bool contains_disk_pq_file);

#ifdef EXEC_ENV_OLS
// Set to a larger value than the actual header to accommodate
// any additions we make to the header. This is an outer limit
Expand Down
95 changes: 93 additions & 2 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,13 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
this->_max_nthreads = num_threads;

#endif

// Validate data type (called once for both EXEC_ENV_OLS and non-OLS paths)
if (!validate_vector_data_type(disk_nnodes, _use_disk_index_pq))
{
throw diskann::ANNException("Data type validation failed. Please ensure --data_type matches "
"the type used when building the index.",
-1, __FUNCSIG__, __FILE__, __LINE__);
}
#ifdef EXEC_ENV_OLS
if (files.fileExists(medoids_file))
{
Expand Down Expand Up @@ -1768,7 +1774,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
{
full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist));
}

uint32_t *node_nbrs = (node_buf + 1);
// compute node_nbrs <-> query dist in PQ space
cpu_timer.reset();
Expand Down Expand Up @@ -1971,6 +1977,91 @@ template <typename T, typename LabelT> char *PQFlashIndex<T, LabelT>::getHeaderB
}
#endif

// Sanity-checks that the data type used to load the index is consistent with
// the on-disk node layout. It reads the sector(s) holding node 0 and verifies
// that the node's neighbor list is plausible for the assumed vector size.
//
// disk_nnodes           - total node count in the index; every neighbor id of
//                         node 0 must be strictly smaller than this value.
// contains_disk_pq_file - currently unused; the per-point size is taken from
//                         _disk_bytes_per_point instead.
//
// Returns true when node 0 looks consistent, false otherwise (a diagnostic is
// written to diskann::cerr before returning false).
template <typename T, typename LabelT>
bool PQFlashIndex<T, LabelT>::validate_vector_data_type(uint64_t disk_nnodes, bool contains_disk_pq_file)
{
    (void)contains_disk_pq_file; // Suppress unused parameter warning

    // Use _disk_bytes_per_point which is already set correctly:
    // - For disk PQ: _disk_pq_n_chunks * sizeof(uint8_t)
    // - For regular: _data_dim * sizeof(T)
    const size_t vector_length = _disk_bytes_per_point;

    // A node must hold the vector, the neighbor count and at least one
    // neighbor slot; otherwise the assumed vector size is too large.
    if (vector_length + sizeof(uint32_t) >= _max_node_len)
    {
        diskann::cerr << "Vector length : " << vector_length << " and neighbor count size : " << sizeof(uint32_t)
                      << ", expected less than max node length : " << _max_node_len << std::endl;
        diskann::cerr << "Please check if wrong data type with larger size is "
                         "specified, like use float type to load byte index!"
                      << std::endl;
        return false;
    }

    // Owns the sector buffer so that it is released on every exit path,
    // including an exception propagating out of the read below.
    struct AlignedBufferGuard
    {
        char *ptr = nullptr;

        AlignedBufferGuard() = default;
        AlignedBufferGuard(const AlignedBufferGuard &) = delete;
        AlignedBufferGuard &operator=(const AlignedBufferGuard &) = delete;

        ~AlignedBufferGuard()
        {
            if (ptr != nullptr)
            {
                aligned_free(ptr);
            }
        }
    };

    // Borrow thread data for the read
    ScratchStoreManager<SSDThreadData<T>> manager(this->_thread_data);
    auto this_thread_data = manager.scratch_space();
    IOContext &ctx = this_thread_data->ctx;

    // Calculate the number of sectors needed per node
    const uint64_t num_sectors_per_node =
        _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
    const uint64_t read_bytes = num_sectors_per_node * defaults::SECTOR_LEN;

    // Allocate sector-aligned buffer (required for direct I/O)
    AlignedBufferGuard buffer;
    alloc_aligned(reinterpret_cast<void **>(&buffer.ptr), read_bytes, defaults::SECTOR_LEN);
    if (buffer.ptr == nullptr)
    {
        // Defensive: never hand a null destination to the reader.
        diskann::cerr << "Failed to allocate aligned buffer, can't validate data type!" << std::endl;
        return false;
    }
    char *buf = buffer.ptr;

    // Read sector(s) containing node 0
    const uint64_t node_sector = get_node_sector(0);
    std::vector<AlignedRead> read_requests;
    read_requests.emplace_back(node_sector * defaults::SECTOR_LEN, read_bytes, buf);
    reader->read(read_requests, ctx);

#if defined(_WINDOWS) && defined(USE_BING_INFRA)
    if ((*ctx.m_pRequestsStatus)[0] != IOContext::READ_SUCCESS)
    {
        diskann::cerr << "Read disk index file " << _disk_index_file << " failed, can't validate data type!"
                      << std::endl;
        return false;
    }
#endif

    // Use offset_to_node to get correct position within the sector buffer for node 0
    char *first_node = offset_to_node(buf, 0);

    // Use offset_to_node_nhood which correctly uses _disk_bytes_per_point.
    // neighbors[0] is the neighbor count; the ids follow at neighbors[1..count].
    uint32_t *neighbors = offset_to_node_nhood(first_node);
    const uint32_t neighbor_count = *neighbors;

    // Calculate max degree based on the assumed vector length
    // max_node_len = vector_length + sizeof(uint32_t) + max_degree * sizeof(uint32_t)
    // So: max_degree = (max_node_len - vector_length - sizeof(uint32_t)) / sizeof(uint32_t)
    const uint32_t max_degree =
        static_cast<uint32_t>((_max_node_len - vector_length - sizeof(uint32_t)) / sizeof(uint32_t));

    // Checking the count first also keeps the id scan below inside the node.
    if (neighbor_count > max_degree)
    {
        diskann::cerr << "Calculated max neighbor count : " << max_degree
                      << " and first node neighbor count : " << neighbor_count << ", load data type is not correct!"
                      << std::endl;
        return false;
    }

    for (uint32_t i = 1; i <= neighbor_count; ++i)
    {
        if (neighbors[i] >= disk_nnodes)
        {
            diskann::cerr << "Neighbor[" << i - 1 << "], index : " << neighbors[i]
                          << ", greater than total node count : " << disk_nnodes << ", load data type is not correct!"
                          << std::endl;
            return false;
        }
    }

    return true;
}

template <typename T, typename LabelT>
std::vector<std::uint8_t> PQFlashIndex<T, LabelT>::get_pq_vector(std::uint64_t vid)
{
Expand Down
Loading