Facebook sync (2020-03-10) by beauby · Pull Request #1136 · facebookresearch/faiss · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Facebook sync (2020-03-10) #1136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
371 changes: 318 additions & 53 deletions Clustering.cpp

Large diffs are not rendered by default.

50 changes: 39 additions & 11 deletions Clustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,31 @@ struct ClusteringParameters {
bool verbose;
bool spherical; ///< do we want normalized centroids?
bool int_centroids; ///< round centroids coordinates to integer
bool update_index; ///< update index after each iteration?
bool update_index; ///< re-train index after each iteration?
bool frozen_centroids; ///< use the centroids provided as input and do not change them during iterations

int min_points_per_centroid; ///< otherwise you get a warning
int max_points_per_centroid; ///< to limit size of dataset

int seed; ///< seed for the random number generator

size_t decode_block_size; /// < how many vectors at a time to decode

/// sets reasonable defaults
ClusteringParameters ();
};


/** clustering based on assignment - centroid update iterations
/// Statistics recorded for a single iteration of clustering.
struct ClusteringIterationStats {
float obj; ///< objective values (sum of distances reported by index)
double time; ///< seconds for iteration
double time_search; ///< seconds for just search
double imbalance_factor; ///< imbalance factor of iteration
int nsplit; ///< number of cluster splits
};


/** K-means clustering based on assignment - centroid update iterations
*
* The clustering is based on an Index object that assigns training
* points to the centroids. Therefore, at each iteration the centroids
Expand All @@ -50,27 +61,44 @@ struct ClusteringParameters {
* centroids table it is not empty on input, it is also used for
* initialization.
*
* To do several clusterings, just call train() several times on
* different training sets, clearing the centroid table in between.
*/
struct Clustering: ClusteringParameters {
typedef Index::idx_t idx_t;
size_t d; ///< dimension of the vectors
size_t k; ///< nb of centroids

/// centroids (k * d)
/** centroids (k * d)
* if centroids are set on input to train, they will be used as initialization
*/
std::vector<float> centroids;

/// objective values (sum of distances reported by index) over
/// iterations
std::vector<float> obj;
/// stats at every iteration of clustering
std::vector<ClusteringIterationStats> iteration_stats;

/// the only mandatory parameters are k and d
Clustering (int d, int k);
Clustering (int d, int k, const ClusteringParameters &cp);

/// Index is used during the assignment stage
virtual void train (idx_t n, const float * x, faiss::Index & index);
/** run k-means training
*
* @param x training vectors, size n * d
* @param index index used for assignment
* @param x_weights weight associated to each vector: NULL or size n
*/
virtual void train (idx_t n, const float * x, faiss::Index & index,
const float *x_weights = nullptr);


/** run with encoded vectors
*
* In addition to train()'s parameters, takes a codec as parameter
* to decode the input vectors.
*
* @param codec codec used to decode the vectors (nullptr =
* vectors are in fact floats)
*/
void train_encoded (idx_t nx, const uint8_t *x_in,
const Index * codec, Index & index,
const float *weights = nullptr);

/// Post-process the centroids after each centroid update.
/// includes optional L2 normalization and nearest integer rounding
Expand Down
267 changes: 267 additions & 0 deletions DirectMap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/DirectMap.h>

#include <cstdio>
#include <cassert>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>

namespace faiss {

/// Default constructor: a freshly built DirectMap starts with type NoMap.
DirectMap::DirectMap ()
    : type (NoMap)
{
}

/** Switch the direct map to another representation and rebuild its content
 * from the inverted lists.
 *
 * @param new_type  requested representation: NoMap, Array or Hashtable
 * @param invlists  inverted lists that are scanned to rebuild the map; it is
 *                  not dereferenced when new_type is NoMap or when new_type
 *                  equals the current type
 * @param ntotal    size given to the array (Array) or reserve hint
 *                  (Hashtable)
 *
 * Throws through the FAISS_THROW_* macros when new_type is not one of the
 * three known values, or, in Array mode, when an id stored in the inverted
 * lists lies outside [0, ntotal).
 *
 * If rebuilding fails, the map is emptied and reset to NoMap before the
 * exception is propagated, so that a later call with the same new_type does
 * not take the "nothing to do" shortcut on a partially filled map.
 */
void DirectMap::set_type (Type new_type, const InvertedLists *invlists, size_t ntotal) {

    FAISS_THROW_IF_NOT (new_type == NoMap || new_type == Array ||
                        new_type == Hashtable);

    if (new_type == type) {
        // nothing to do
        return;
    }

    array.clear ();
    hashtable.clear ();
    type = new_type;

    if (new_type == NoMap) {
        return;
    } else if (new_type == Array) {
        // every slot starts as -1 ("no entry") until a list claims it
        array.resize (ntotal, -1);
    } else if (new_type == Hashtable) {
        hashtable.reserve (ntotal);
    }

    try {
        for (size_t key = 0; key < invlists->nlist; key++) {
            const size_t list_size = invlists->list_size (key);
            InvertedLists::ScopedIds idlist (invlists, key);

            if (new_type == Array) {
                for (size_t ofs = 0; ofs < list_size; ofs++) {
                    FAISS_THROW_IF_NOT_MSG (
                        0 <= idlist [ofs] && idlist[ofs] < ntotal,
                        "direct map supported only for sequential ids");
                    array [idlist [ofs]] = lo_build (key, ofs);
                }
            } else if (new_type == Hashtable) {
                for (size_t ofs = 0; ofs < list_size; ofs++) {
                    hashtable [idlist [ofs]] = lo_build (key, ofs);
                }
            }
        }
    } catch (...) {
        // do not leave a half-built map flagged as valid
        array.clear ();
        hashtable.clear ();
        type = NoMap;
        throw;
    }
}

/// Empty both backing containers (the hashtable and the array).
/// The `type` field is left untouched by this function.
void DirectMap::clear ()
{
    hashtable.clear ();
    array.clear ();
}


/** Look up the packed (list number, offset) entry stored for an id.
 *
 * @param key  id to look up
 * @return     the entry stored in the map for that id
 *
 * Throws through the FAISS_THROW_* macros when the map type is neither
 * Array nor Hashtable, when the key is out of range or absent, or when the
 * array slot for the key holds a negative value.
 */
DirectMap::idx_t DirectMap::get (idx_t key) const
{
    if (type == Hashtable) {
        auto res = hashtable.find (key);
        FAISS_THROW_IF_NOT_MSG (res != hashtable.end(), "key not found");
        return res->second;
    }

    if (type != Array) {
        FAISS_THROW_MSG ("direct map not initialized");
    }

    // Array representation: the id is used directly as an index
    FAISS_THROW_IF_NOT_MSG (
        key >= 0 && key < array.size(), "invalid key"
    );
    idx_t lo = array[key];
    FAISS_THROW_IF_NOT_MSG(lo >= 0, "-1 entry in direct_map");
    return lo;
}



/** Record the location of one newly added id.
 *
 * @param id       id of the added vector; in Array mode it is asserted to
 *                 equal the current array size
 * @param list_no  list that received the vector; a negative value means
 *                 there is no location to record
 * @param offset   position inside that list
 *
 * Does nothing when the map type is NoMap.
 */
void DirectMap::add_single_id (idx_t id, idx_t list_no, size_t offset)
{
    if (type == Hashtable) {
        // ids without a valid list are simply not inserted
        if (list_no < 0) {
            return;
        }
        hashtable[id] = lo_build (list_no, offset);
        return;
    }

    if (type != Array) {
        return;
    }

    assert (id == array.size());
    // -1 marks an id with no list assigned
    idx_t entry = -1;
    if (list_no >= 0) {
        entry = lo_build (list_no, offset);
    }
    array.push_back (entry);
}

/// Throws when explicit ids are supplied while the map type is Array;
/// otherwise returns without doing anything.
void DirectMap::check_can_add (const idx_t *ids) {
    if (ids == nullptr || type != Array) {
        return;
    }
    FAISS_THROW_MSG ("cannot have array direct map and add with ids");
}

/********************* DirectMapAdd implementation */


/** Prepare a batch addition of n entries to a direct map.
 *
 * @param direct_map  map that will receive the new entries
 * @param n           number of entries in the batch
 * @param xids        explicit ids of the entries; must be nullptr in Array
 *                    mode and non-null in Hashtable mode
 *
 * Array mode: remembers the current array size in `ntotal` and grows the
 * array by n slots initialised to -1.
 * Hashtable mode: allocates a temporary array of n offsets that the
 * destructor copies into the hashtable.
 *
 * `ntotal` is only assigned in Array mode, yet the destructor would fall
 * back on `ntotal + i` when no ids are given. Explicit ids are therefore
 * required in Hashtable mode, so that uninitialised value is never used.
 */
DirectMapAdd::DirectMapAdd (DirectMap &direct_map, size_t n, const idx_t *xids):
    direct_map(direct_map), type(direct_map.type), n(n), xids(xids)
{
    if (type == DirectMap::Array) {
        FAISS_THROW_IF_NOT (xids == nullptr);
        ntotal = direct_map.array.size();
        direct_map.array.resize (ntotal + n, -1);
    } else if (type == DirectMap::Hashtable) {
        FAISS_THROW_IF_NOT_MSG (
            xids != nullptr,
            "hashtable direct map requires explicit ids");
        // can't parallel update hashtable so use temp array
        all_ofs.resize (n, -1);
    }
}


/** Record the location of the i-th entry of the batch.
 *
 * @param i        position of the entry inside the batch
 * @param list_no  list that received the entry
 * @param ofs      position inside that list
 *
 * Hashtable mode writes into the temporary offsets array; Array mode writes
 * directly into the slot at ntotal + i. Any other type is a no-op.
 */
void DirectMapAdd::add (size_t i, idx_t list_no, size_t ofs)
{
    if (type == DirectMap::Hashtable) {
        all_ofs [i] = lo_build (list_no, ofs);
        return;
    }
    if (type == DirectMap::Array) {
        direct_map.array [ntotal + i] = lo_build (list_no, ofs);
    }
}

/** Flush the batch into the map.
 *
 * Only Hashtable mode has work to do here: the offsets collected in the
 * temporary array are copied into the hashtable, keyed by xids[i] when ids
 * were supplied and by ntotal + i otherwise.
 *
 * NOTE(review): the constructor in this file assigns `ntotal` only in Array
 * mode, so the `ntotal + i` fallback is meaningful only if `ntotal` is
 * initialised elsewhere -- confirm against the class declaration.
 */
DirectMapAdd::~DirectMapAdd ()
{
    if (type != DirectMap::Hashtable) {
        return;
    }
    // n is a size_t: index with the same type to avoid a signed/unsigned
    // comparison and int overflow on large batches
    for (size_t i = 0; i < n; i++) {
        idx_t id = xids ? xids[i] : ntotal + i;
        direct_map.hashtable [id] = all_ofs [i];
    }
}

/********************************************************/

// Shorthand aliases for InvertedLists::ScopedCodes and
// InvertedLists::ScopedIds, used by the functions below.
using ScopedCodes = InvertedLists::ScopedCodes;
using ScopedIds = InvertedLists::ScopedIds;


size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists *invlists)
{
size_t nlist = invlists->nlist;
std::vector<idx_t> toremove(nlist);

size_t nremove = 0;

if (type == NoMap) {
// exhaustive scan of IVF
#pragma omp parallel for
for (idx_t i = 0; i < nlist; i++) {
idx_t l0 = invlists->list_size (i), l = l0, j = 0;
ScopedIds idsi (invlists, i);
while (j < l) {
if (sel.is_member (idsi[j])) {
l--;
invlists->update_entry (
i, j,
invlists->get_single_id (i, l),
ScopedCodes (invlists, i, l).get()
);
} else {
j++;
}
}
toremove[i] = l0 - l;
}
// this will not run well in parallel on ondisk because of
// possible shrinks
for (idx_t i = 0; i < nlist; i++) {
if (toremove[i] > 0) {
nremove += toremove[i];
invlists->resize(i, invlists->list_size(i) - toremove[i]);
}
}
} else if (type == Hashtable) {
const IDSelectorArray *sela =
dynamic_cast<const IDSelectorArray*>(&sel);
FAISS_THROW_IF_NOT_MSG (
sela,
"remove with hashtable works only with IDSelectorArray"
);

for (idx_t i = 0; i < sela->n; i++) {
idx_t id = sela->ids[i];
auto res = hashtable.find (id);
if (res != hashtable.end()) {
size_t list_no = lo_listno (res->second);
size_t offset = lo_offset (res->second);
idx_t last = invlists->list_size (list_no) - 1;< C6C3 /td>
hashtable.erase (res);
if (offset < last) {
idx_t last_id = invlists->get_single_id (list_no, last);
invlists->update_entry (
list_no, offset,
last_id,
ScopedCodes (invlists, list_no, last).get()
);
// update hash entry for last element
hashtable [last_id] = list_no << 32 | offset;
}
invlists->resize(list_no, last);
nremove++;
}
}

} else {
FAISS_THROW_MSG("remove not supported with this direct_map format");
}
return nremove;
}

/** Replace the codes of already stored vectors, moving each of them to the
 * list given in `assign`. Only supported when the map type is Array.
 *
 * @param invlists  inverted lists to modify
 * @param n         number of vectors to update; n <= 0 is a no-op
 * @param ids       ids of the vectors to update, size n
 * @param assign    destination list of each vector, size n
 * @param codes     new codes, n * invlists->code_size bytes
 *
 * Throws through the FAISS_THROW_* macros when the map type is not Array,
 * when an id is outside the array, or when the array holds no location
 * (negative entry) for an id.
 */
void DirectMap::update_codes (InvertedLists *invlists,
                              int n, const idx_t *ids,
                              const idx_t *assign,
                              const uint8_t *codes)
{
    FAISS_THROW_IF_NOT (type == Array);

    size_t code_size = invlists->code_size;

    // n is an int: iterate with the same signedness so that a negative
    // count does not turn into a huge unsigned bound
    for (int i = 0; i < n; i++) {
        idx_t id = ids[i];
        FAISS_THROW_IF_NOT_MSG (0 <= id && id < array.size(),
                                "id to update out of range");
        { // remove old one
            idx_t dm = array [id];
            // a negative entry carries no (list, offset) to decode
            FAISS_THROW_IF_NOT_MSG (dm >= 0, "-1 entry in direct_map");
            int64_t ofs = lo_offset (dm);
            int64_t il = lo_listno (dm);
            size_t l = invlists->list_size (il);
            if (static_cast<size_t>(ofs) != l - 1) { // move l - 1 to ofs
                int64_t id2 = invlists->get_single_id (il, l - 1);
                array[id2] = lo_build (il, ofs);
                // same scoped accessor as remove_ids() uses for this move
                invlists->update_entry (il, ofs, id2,
                                        ScopedCodes (invlists, il, l - 1).get());
            }
            invlists->resize (il, l - 1);
        }
        { // insert new one
            int64_t il = assign[i];
            // the new entry is appended, so its offset is the current size
            size_t l = invlists->list_size (il);
            idx_t dm = lo_build (il, l);
            array [id] = dm;
            invlists->add_entry (il, id,
                                 codes + static_cast<size_t>(i) * code_size);
        }
    }
}


}
Loading
0