Migrate sparse knn and distances code from raft by benfred · Pull Request #457 · rapidsai/cuvs · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Migrate sparse knn and distances code from raft #457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ if(BUILD_SHARED_LIBS)
src/distance/detail/fused_distance_nn.cu
src/distance/distance.cu
src/distance/pairwise_distance.cu
src/distance/sparse_distance.cu
src/neighbors/brute_force.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_half.cu
Expand Down Expand Up @@ -449,6 +450,7 @@ if(BUILD_SHARED_LIBS)
src/neighbors/refine/detail/refine_host_int8_t_float.cpp
src/neighbors/refine/detail/refine_host_uint8_t_float.cpp
src/neighbors/sample_filter.cu
src/neighbors/sparse_brute_force.cu
src/neighbors/vamana_build_float.cu
src/neighbors/vamana_build_uint8.cu
src/neighbors/vamana_build_int8.cu
Expand Down
81 changes: 81 additions & 0 deletions cpp/include/cuvs/distance/distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <cstdint>
#include <cuda_fp16.h>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

Expand Down Expand Up @@ -331,6 +332,86 @@ void pairwise_distance(
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f);

/**
 * @brief Compute sparse pairwise distances between x and y, using the provided
 * input configuration and distance function.
 *
 * Both inputs are CSR matrices in device memory with the same number of
 * columns. The dense output `dist` is expected to have one row per row of `x`
 * and one column per row of `y` (see the example below).
 *
 * @code{.cpp}
 * #include <cuvs/distance/distance.hpp>
 * #include <raft/core/device_resources.hpp>
 * #include <raft/core/device_csr_matrix.hpp>
 * #include <raft/core/device_mdarray.hpp>
 * #include <raft/core/device_mdspan.hpp>
 *
 * int x_n_rows = 100000;
 * int y_n_rows = 50000;
 * int n_cols = 10000;
 *
 * raft::device_resources handle;
 * auto x = raft::make_device_csr_matrix<float>(handle, x_n_rows, n_cols);
 * auto y = raft::make_device_csr_matrix<float>(handle, y_n_rows, n_cols);
 *
 * ...
 * // populate data
 * ...
 *
 * auto out    = raft::make_device_matrix<float>(handle, x_n_rows, y_n_rows);
 * auto metric = cuvs::distance::DistanceType::L2Expanded;
 * cuvs::distance::pairwise_distance(handle, x.view(), y.view(), out.view(), metric);
 * @endcode
 *
 * @param[in] handle raft::resources
 * @param[in] x raft::device_csr_matrix_view of the left-hand inputs
 * @param[in] y raft::device_csr_matrix_view of the right-hand inputs
 * @param[out] dist raft::device_matrix_view dense output matrix, shaped [rows of x, rows of y]
 * @param[in] metric distance metric to use
 * @param[in] metric_arg metric argument (used for Minkowski distance); defaults to 2.0
 */
void pairwise_distance(raft::resources const& handle,
                       raft::device_csr_matrix_view<const float, int, int, int> x,
                       raft::device_csr_matrix_view<const float, int, int, int> y,
                       raft::device_matrix_view<float, int, raft::row_major> dist,
                       cuvs::distance::DistanceType metric,
                       float metric_arg = 2.0f);

/**
 * @brief Compute sparse pairwise distances between x and y, using the provided
 * input configuration and distance function.
 *
 * Double-precision overload. Both inputs are CSR matrices in device memory with
 * the same number of columns. The dense output `dist` is expected to have one
 * row per row of `x` and one column per row of `y` (see the example below).
 * Note that `metric_arg` is still a `float` in this overload.
 *
 * @code{.cpp}
 * #include <cuvs/distance/distance.hpp>
 * #include <raft/core/device_resources.hpp>
 * #include <raft/core/device_csr_matrix.hpp>
 * #include <raft/core/device_mdarray.hpp>
 * #include <raft/core/device_mdspan.hpp>
 *
 * int x_n_rows = 100000;
 * int y_n_rows = 50000;
 * int n_cols = 10000;
 *
 * raft::device_resources handle;
 * auto x = raft::make_device_csr_matrix<double>(handle, x_n_rows, n_cols);
 * auto y = raft::make_device_csr_matrix<double>(handle, y_n_rows, n_cols);
 *
 * ...
 * // populate data
 * ...
 *
 * auto out    = raft::make_device_matrix<double>(handle, x_n_rows, y_n_rows);
 * auto metric = cuvs::distance::DistanceType::L2Expanded;
 * cuvs::distance::pairwise_distance(handle, x.view(), y.view(), out.view(), metric);
 * @endcode
 *
 * @param[in] handle raft::resources
 * @param[in] x raft::device_csr_matrix_view of the left-hand inputs
 * @param[in] y raft::device_csr_matrix_view of the right-hand inputs
 * @param[out] dist raft::device_matrix_view dense output matrix, shaped [rows of x, rows of y]
 * @param[in] metric distance metric to use
 * @param[in] metric_arg metric argument (used for Minkowski distance); defaults to 2.0
 */
void pairwise_distance(raft::resources const& handle,
                       raft::device_csr_matrix_view<const double, int, int, int> x,
                       raft::device_csr_matrix_view<const double, int, int, int> y,
                       raft::device_matrix_view<double, int, raft::row_major> dist,
                       cuvs::distance::DistanceType metric,
                       float metric_arg = 2.0f);

/** @} */ // end group pairwise_distance_runtime

}; // namespace cuvs::distance
104 changes: 104 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "common.hpp"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
Expand Down Expand Up @@ -375,4 +376,107 @@ void search(raft::resources const& handle,
* @}
*/

/**
 * @defgroup sparse_bruteforce_cpp_index Sparse Brute Force index
 * @{
 */
/**
 * @brief Sparse Brute Force index.
 *
 * The index does not copy the dataset: its only data state is the
 * `raft::device_csr_matrix_view` handed to the constructor, plus the metric and
 * metric argument. The device memory behind that view must therefore remain
 * valid for as long as the index (or any index moved from it) is used.
 *
 * The type is move-only: copy construction and copy assignment are deleted.
 *
 * @tparam T Data element type
 * @tparam IdxT Index element type
 */
template <typename T, typename IdxT>
struct sparse_index {
 public:
  sparse_index(const sparse_index&)            = delete;
  sparse_index(sparse_index&&)                 = default;
  sparse_index& operator=(const sparse_index&) = delete;
  sparse_index& operator=(sparse_index&&)      = default;
  ~sparse_index()                              = default;

  /**
   * @brief Construct a sparse brute force sparse_index from dataset.
   *
   * @param[in] res raft resources handle
   * @param[in] dataset CSR matrix in device memory. Only the view is kept, so the
   *            underlying buffers must outlive the index.
   * @param[in] metric distance metric used for retrieval
   * @param[in] metric_arg metric argument (used by parameterized metrics such as Minkowski)
   */
  sparse_index(raft::resources const& res,
               raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> dataset,
               cuvs::distance::DistanceType metric,
               T metric_arg);

  /** Distance metric used for retrieval */
  cuvs::distance::DistanceType metric() const noexcept { return metric_; }

  /** Metric argument */
  T metric_arg() const noexcept { return metric_arg_; }

  /**
   * Non-owning view of the dataset this index searches against. It refers to the
   * same memory that was passed to the constructor.
   */
  raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> dataset() const noexcept
  {
    return dataset_;
  }

 private:
  // Non-owning: the caller retains ownership of the CSR buffers.
  raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> dataset_;
  cuvs::distance::DistanceType metric_;
  T metric_arg_;
};
/**
 * @}
 */

/**
 * @defgroup sparse_bruteforce_cpp_index_build Sparse Brute Force index build
 * @{
 */

/**
 * @brief Build the Sparse index from the dataset
 *
 * The returned index keeps a view of `dataset` rather than a copy, so the
 * dataset's device memory must outlive the index.
 *
 * Usage example:
 * @code{.cpp}
 * using namespace cuvs::neighbors;
 * // create and fill the index from a CSR dataset
 * auto index = brute_force::build(handle, dataset, metric);
 * @endcode
 *
 * @param[in] handle raft resources handle
 * @param[in] dataset A sparse CSR matrix in device memory to search against
 * @param[in] metric cuvs::distance::DistanceType (defaults to L2Unexpanded)
 * @param[in] metric_arg metric argument (defaults to 0)
 *
 * @return the constructed Sparse brute-force index
 */
auto build(raft::resources const& handle,
           raft::device_csr_matrix_view<const float, int, int, int> dataset,
           cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
           float metric_arg = 0) -> cuvs::neighbors::brute_force::sparse_index<float, int>;
/**
 * @}
 */

/**
 * @defgroup sparse_bruteforce_cpp_index_search Sparse Brute Force index search
 * @{
 */
/**
 * Tuning knobs for the sparse brute-force search. Both batch sizes default to
 * 32768 (i.e. 2 << 14).
 */
struct sparse_search_params {
  /** Batch size applied to the index (dataset) side of the search. */
  int batch_size_index{1 << 15};
  /** Batch size applied to the query side of the search. */
  int batch_size_query{1 << 15};
};

/**
 * @brief Search the sparse bruteforce index for nearest neighbors
 *
 * @param[in] handle raft resources handle
 * @param[in] params batching parameters for the search (see sparse_search_params)
 * @param[in] index Sparse brute-force constructed index
 * @param[in] queries a sparse CSR matrix on the device to query
 * @param[out] neighbors device matrix view receiving the indices of the neighbors in the
 * source dataset [n_queries, k]
 * @param[out] distances device matrix view receiving the distances to the selected neighbors
 * [n_queries, k]
 */
void search(raft::resources const& handle,
            const sparse_search_params& params,
            const sparse_index<float, int>& index,
            raft::device_csr_matrix_view<const float, int, int, int> queries,
            raft::device_matrix_view<int, int64_t, raft::row_major> neighbors,
            raft::device_matrix_view<float, int64_t, raft::row_major> distances);
/**
 * @}
 */
} // namespace cuvs::neighbors::brute_force
Loading
Loading
0