Implemented Papers • Requirements and Installation • Getting Started • Benchmarks • FAQ • Acknowledgements
- 2022/10/04 kNN-models is publicly available
- 2022/11/20 Added support for retrieval with Elasticsearch; please refer to the `examples/es_knnmt` folder for more details
kNN-models is a k-nearest-neighbor augmented sequence modeling toolkit built on Fairseq. It enhances pre-trained neural sequence-to-sequence models by retrieving from external memory, without expensive retraining.
Main features:
- Fast and memory efficient (please see benchmarks for details)
- Provides reference implementations of various k-nearest-neighbor augmented sequence modeling papers (please see Implemented Papers for details)
- Compatible with most of the pre-trained models in Fairseq (although only the Transformer model has been thoroughly tested so far; we plan to experiment with other models in the future)
- Supports similarity search with Faiss and Elasticsearch
- The Faiss index can be placed on a GPU different from the one occupied by the model, and sharded across multiple GPUs, to avoid running out of memory (see the sketch after this list)
- The module that produces the intermediate hidden states serving as datastore keys can be configured through command line arguments to suit the user's needs (it is the last layer of the decoder by default; please see `BaseKnnConfig` for details)
- Flexible configuration based on Hydra
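
For example, with the Faiss Python API an index can be pinned to a GPU other than the one running the model, or sharded across all visible GPUs. A minimal sketch (the index path is hypothetical; kNN-models handles this for you through its configuration):

```python
import faiss

# Load a previously built index (hypothetical path).
index = faiss.read_index("datastore/keys.index")

# Place the index on GPU 1 while the model occupies GPU 0.
res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 1, index)

# Alternatively, shard the index across all visible GPUs.
options = faiss.GpuMultipleClonerOptions()
options.shard = True
sharded_index = faiss.index_cpu_to_all_gpus(index, co=options)
```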
## Implemented Papers

The repository contains the reference implementation of the following papers (sorted by publication date):
- What Knowledge Is Needed? Towards Explainable Memory for kNN-MT Domain Adaptation (arXiv)
- Simple and Scalable Nearest Neighbor Machine Translation (ICLR 2023)
- Towards Robust k-Nearest-Neighbor Machine Translation (EMNLP 2022)
- Efficient Cluster-Based k-Nearest-Neighbor Machine Translation (ACL 2022)
- Efficient Nearest Neighbor Language Models (EMNLP 2021)
- Learning Kernel-Smoothed Machine Translation with Retrieved Examples (EMNLP 2021)
- Adaptive Nearest Neighbor Machine Translation (ACL 2021)
- Nearest Neighbor Machine Translation (ICLR 2021)
- Generalization through Memorization: Nearest Neighbor Language Models (ICLR 2020)
The detailed READMEs about how to reproduce them with kNN-models can be found in the examples folder.
## Requirements and Installation

The repository is developed and tested on Python 3.10, PyTorch 1.10.0, Fairseq 0.12.1, Faiss-gpu 1.7.2, and Elasticsearch 8.5.0. We recommend that users keep the versions of these packages the same as ours to avoid compatibility issues, even though other versions may also work.
To install kNN-models and develop locally:
```bash
git clone https://github.com/cordercorder/knn-models.git
cd knn-models
pip install -e ./
```
Note that `pip install -e ./` will check the packages in the Python environment to resolve the dependencies specified in `requirements.txt`. However, `faiss-gpu` installed through conda cannot be identified by pip, which will result in a redundant Faiss installation from PyPI. If you are sure that all the packages required by this repository are installed, you can delete `faiss-gpu>=1.7.2` from `requirements.txt` and run `python setup.py develop` to install kNN-models instead.
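
For example, a conda-based setup might look like this (a sketch of the workflow described above; faiss-gpu is distributed through the pytorch conda channel):

```bash
# Install faiss-gpu from conda first, so pip does not pull in a second Faiss build.
conda install -c pytorch faiss-gpu=1.7.2

# After deleting the `faiss-gpu>=1.7.2` line from requirements.txt:
python setup.py develop
```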
If you want to use Elasticsearch for k-nearest neighbor search rather than Faiss, please set up Elasticsearch following the instructions first and then install elasticsearch-py through `pip install elasticsearch`.
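
As a quick connectivity check, elasticsearch-py can be used like this (a minimal sketch assuming a local instance on the default port with security disabled; adjust the URL and authentication for your deployment):

```python
from elasticsearch import Elasticsearch

# Connect to a local Elasticsearch instance.
es = Elasticsearch("http://localhost:9200")

# Print cluster information to verify that the connection works.
print(es.info())
```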
## Getting Started

We tried to make the implementation independent of the model architecture while developing this repository. Consequently, we extend the Fairseq task with the ability to perform similarity search. As the task can be combined with different model architectures, we can enhance various pre-trained models with external memory without modifying the official Fairseq code. For example, kNN-MT can be implemented with just a few lines of code, like the following:
```python
from functools import partial
from dataclasses import dataclass

from fairseq.tasks.translation import (
    TranslationTask,
    TranslationConfig,
)
from fairseq.tasks import register_task
from fairseq.dataclass import FairseqDataclass

from knn_models.dataclass import KnnConfig
from knn_models.hook_utils import ForwardHook
from knn_models.knn_utils import (
    KnnSearch,
    get_captured_module,
    get_normalized_probs,
)


@dataclass
class TranslationKnnConfig(TranslationConfig):
    """Config for nearest neighbor machine translation."""
    knn_config: KnnConfig = KnnConfig()


@register_task("translation_knn", dataclass=TranslationKnnConfig)
class TranslationKnnTask(TranslationTask):
    """Task for nearest neighbor machine translation."""

    def __init__(self, cfg: TranslationKnnConfig, src_dict, tgt_dict):
        super().__init__(cfg, src_dict, tgt_dict)
        self.knn_search = KnnSearch(cfg.knn_config)
        self.forward_hook = ForwardHook()

    def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
        model = super().build_model(cfg, from_checkpoint)

        assert hasattr(model, "decoder"), \
            "TranslationKnnTask only supports models with a decoder! " \
            f"There is no decoder in {model.__class__.__name__}."

        # collect outputs from the specified module in the decoder as the datastore keys
        captured_module_name = self.cfg.knn_config.module_to_capture
        captured_module = get_captured_module(model.decoder, captured_module_name)
        captured_module.register_forward_hook(self.forward_hook.forward_hook_function)

        # rewrite the `get_normalized_probs` function to support kNN augmented NMT
        model.get_normalized_probs = partial(get_normalized_probs, self, model)
        return model
```
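
At the core of the rewritten `get_normalized_probs` in kNN-MT-style approaches is interpolating the NMT distribution with the retrieval distribution, p(y) = λ · p_kNN(y) + (1 − λ) · p_MT(y). A minimal sketch of this interpolation in log space (function and argument names are illustrative, not the actual kNN-models API):

```python
import math
import torch


def interpolate_log_probs(mt_log_probs, knn_log_probs, lam=0.7):
    """Compute log(lam * p_knn + (1 - lam) * p_mt) in a numerically stable way."""
    stacked = torch.stack([
        knn_log_probs + math.log(lam),
        mt_log_probs + math.log(1.0 - lam),
    ])
    return torch.logsumexp(stacked, dim=0)
```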
## Benchmarks

We measured the generation speed and GPU memory consumption during inference to evaluate the performance of kNN-models. We conducted experiments on kNN-MT and Adaptive kNN-MT, given that they are the dominant approaches to retrieval-augmented MT.
Following common practice, we used the multi-domain dataset (Koehn & Knowles, 2017), re-split by Aharoni & Goldberg (2020), for the experiments, and adopted the winner of the WMT'19 German-English news translation task (Ng et al., 2019) as the pre-trained NMT model. For kNN-MT, we tuned the hyperparameters (`num_neighbors`, `lambda`, `temperature`) on the validation sets according to the BLEU score. The hyperparameters for Adaptive kNN-MT were inherited from kNN-MT, except for `lambda`, which is inferred by the Meta-k-Network of Adaptive kNN-MT. We employed beam search with a beam size of 5 and a length penalty of 1.0 during decoding. Note that only one GPU was used throughout the benchmark experiments, and the Faiss index was placed on the GPU to speed up the search operation.

The datastore size and the hyperparameters for each domain are presented below:
| | Medical | Law | IT | Koran | Subtitles |
|---|---|---|---|---|---|
| datastore size | 6501418 | 18857646 | 3449918 | 519897 | 6209620 |
| num_neighbors | 8 | 8 | 16 | 16 | 16 |
| lambda | 0.7 | 0.7 | 0.6 | 0.7 | 0.5 |
| temperature | 5 | 5 | 5 | 20 | 20 |
The BLEU scores of the pre-trained NMT model (Base MT), kNN-MT, and Adaptive kNN-MT on the test sets for each domain are presented below:
| | Medical | Law | IT | Koran | Subtitles |
|---|---|---|---|---|---|
| Base MT | 41.87 | 45.96 | 38.52 | 17.07 | 29.39 |
| kNN-MT | 57.08 | 62.48 | 47.1 | 22.54 | 30.55 |
| Adaptive kNN-MT | 58.17 | 63.32 | 48.33 | 22.03 | 30.45 |
As the generation speed usually varies between runs and is highly dependent on the hardware environment, we performed each experiment 5 times and report the mean and standard deviation of the generation speed on two different servers.
The generation speed (tokens/s) of kNN-models on a server with 8 NVIDIA Tesla P100 GPUs (16GB), 2 Intel Xeon Gold 6240 CPUs, and 256 GB of RAM is presented below (as there are sentences with more than 400 tokens in the test sets of the medical and law domains, the generation speed is not available with the batch size set to 400 tokens):
| Batch Size | Model | Medical | Law | IT | Koran | Subtitles |
|---|---|---|---|---|---|---|
| 400 tokens | Base MT | N/A | N/A | 593.67±12.92 | 577.60±14.76 | 1005.69±44.67 |
| | kNN-MT | N/A | N/A | 492.66±21.24 | 488.79±20.47 | 858.08±29.71 |
| | Adaptive kNN-MT | N/A | N/A | 470.20±20.02 | 455.39±16.95 | 806.94±24.71 |
| 800 tokens | Base MT | 761.39±29.74 | 705.84±7.99 | 869.02±36.63 | 830.49±34.10 | 1502.55±29.31 |
| | kNN-MT | 625.08±24.04 | 542.48±21.85 | 738.49±31.51 | 689.17±36.21 | 1240.48±21.99 |
| | Adaptive kNN-MT | 591.90±16.39 | 521.86±12.26 | 710.79±17.69 | 642.82±20.04 | 1190.69±15.46 |
| 1600 tokens | Base MT | 1033.93±30.34 | 1000.80±34.31 | 1195.03±41.52 | 1138.84±41.03 | 1859.79±10.62 |
| | kNN-MT | 829.28±22.33 | 743.36±23.23 | 993.22±22.14 | 960.69±27.82 | 1467.16±4.67 |
| | Adaptive kNN-MT | 812.92±13.07 | 715.14±18.86 | 924.22±22.44 | 903.87±16.43 | 1408.14±16.42 |
| 3200 tokens | Base MT | 1335.80±20.57 | 1294.52±15.47 | 1445.16±20.55 | 1497.09±16.30 | 2047.57±19.40 |
| | kNN-MT | 1046.16±16.05 | 940.59±9.40 | 1197.04±18.48 | 1247.45±17.36 | 1586.45±10.99 |
| | Adaptive kNN-MT | 1036.07±3.97 | 917.63±10.08 | 1189.73±5.70 | 1203.48±9.22 | 1577.00±12.18 |
| 6400 tokens | Base MT | 1563.36±11.48 | 1522.87±11.01 | 1613.63±17.39 | 1716.00±11.16 | 2126.56±19.66 |
| | kNN-MT | 1226.55±3.98 | 1072.35±5.72 | 1323.60±14.69 | 1447.19±13.10 | 1660.31±15.97 |
| | Adaptive kNN-MT | 1193.37±13.58 | 1043.77±6.62 | 1293.78±11.54 | 1408.91±7.27 | 1648.06±17.63 |
| 12800 tokens | Base MT | 1675.49±9.45 | 1633.76±9.67 | 1647.95±12.20 | 1803.01±10.18 | 2197.24±13.67 |
| | kNN-MT | 1300.68±6.27 | 1140.59±3.88 | 1334.90±2.23 | 1532.65±8.40 | 1694.99±7.50 |
| | Adaptive kNN-MT | 1275.62±10.28 | 1125.35±5.66 | 1323.47±9.31 | 1500.19±10.48 | 1699.80±10.55 |
The generation speed (tokens/s) of kNN-models on a server with 8 NVIDIA GeForce GTX TITAN GPUs (24GB), 2 Intel Xeon E5-2680 CPUs, and 256 GB of RAM is presented below:
| Batch Size | Model | Medical | Law | IT | Koran | Subtitles |
|---|---|---|---|---|---|---|
| 400 tokens | Base MT | N/A | N/A | 435.83±15.51 | 432.85±16.09 | 844.25±57.33 |
| | kNN-MT | N/A | N/A | 408.02±21.15 | 403.94±16.99 | 759.71±51.01 |
| | Adaptive kNN-MT | N/A | N/A | 393.35±25.35 | 371.31±29.31 | 724.04±42.07 |
| 800 tokens | Base MT | 634.81±15.64 | 588.01±14.00 | 743.54±42.92 | 682.80±19.63 | 1507.27±54.44 |
| | kNN-MT | 542.13±11.21 | 481.48±8.66 | 651.12±31.04 | 618.70±11.19 | 1261.36±44.09 |
| | Adaptive kNN-MT | 526.43±33.34 | 436.25±21.67 | 633.04±29.44 | 556.48±35.99 | 1244.21±69.26 |
| 1600 tokens | Base MT | 967.79±14.60 | 983.15±9.54 | 1110.93±25.45 | 1088.76±41.47 | 2182.40±74.34 |
| | kNN-MT | 761.56±33.66 | 726.35±25.67 | 1040.71±17.07 | 919.17±31.14 | 1664.39±55.27 |
| | Adaptive kNN-MT | 745.29±21.61 | 719.38±27.49 | 969.04±46.21 | 915.46±52.70 | 1601.80±38.00 |
| 3200 tokens | Base MT | 1526.37±43.21 | 1488.71±78.56 | 1665.54±66.93 | 1885.99±13.26 | 2645.62±80.18 |
| | kNN-MT | 1168.07±20.86 | 1051.21±30.82 | 1395.36±63.48 | 1547.67±60.08 | 2040.28±29.90 |
| | Adaptive kNN-MT | 1135.30±63.46 | 1037.96±54.62 | 1335.45±60.56 | 1442.43±52.53 | 2032.88±47.17 |
| 6400 tokens | Base MT | 2078.05±14.57 | 2038.81±60.04 | 2078.64±55.91 | 2397.98±11.12 | 2838.64±12.76 |
| | kNN-MT | 1541.41±31.89 | 1337.22±5.74 | 1698.17±46.67 | 1965.55±43.59 | 2176.18±26.11 |
| | Adaptive kNN-MT | 1494.57±22.87 | 1326.34±24.34 | 1695.56±42.75 | 1902.53±45.91 | 2173.67±25.10 |
| 12800 tokens | Base MT | 2377.90±20.36 | 2374.11±6.77 | 2158.86±21.50 | 2589.23±40.78 | 2986.30±31.20 |
| | kNN-MT | 1752.04±11.44 | 1493.63±5.76 | 1772.20±51.73 | 2175.42±40.24 | 2314.58±6.86 |
| | Adaptive kNN-MT | 1719.02±36.40 | 1476.38±13.23 | 1765.07±47.39 | 2117.49±45.74 | 2313.21±44.98 |
It is nontrivial to accurately measure the minimum amount of GPU memory required for model inference due to the complicated GPU memory management of PyTorch and Faiss. Nevertheless, to report the approximate minimum GPU memory requirement, we disabled PyTorch's memory caching by setting the environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING` to `1` and monitored the maximum amount of used GPU memory every 10 milliseconds. We set the batch size to 12000 tokens, following the default setting of Fairseq. The observed maximum GPU memory consumption of kNN-models during inference is presented below:
| Batch Size | Model | Medical | Law | IT | Koran | Subtitles |
|---|---|---|---|---|---|---|
| 12000 tokens | Base MT | 6363 MB | 6519 MB | 6509 MB | 6575 MB | 6349 MB |
| | kNN-MT | 8391 MB | 9383 MB | 8255 MB | 8155 MB | 8367 MB |
| | Adaptive kNN-MT | 8379 MB | 9403 MB | 8265 MB | 8153 MB | 8375 MB |
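
A simple monitor in the spirit of the measurement above could poll NVML every 10 milliseconds (a sketch assuming the `pynvml` package and GPU 0; not the exact script we used):

```python
import time
import pynvml


def monitor_peak_gpu_memory(device_id=0, interval=0.01):
    """Poll the GPU every `interval` seconds and track the peak used memory."""
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    peak_bytes = 0
    try:
        while True:
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            peak_bytes = max(peak_bytes, info.used)
            time.sleep(interval)
    except KeyboardInterrupt:
        pynvml.nvmlShutdown()
        print(f"peak GPU memory: {peak_bytes / 1024 ** 2:.0f} MB")
```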
## FAQ

Most of the retrieval-augmented sequence modeling approaches implemented in kNN-models rely on two procedures: (1) saving the intermediate hidden states of the model during the forward pass, and (2) retrieving useful information from a datastore according to the saved hidden states to improve the probability distribution over tokens. For the first procedure, we register a forward hook on the model to collect the intermediate hidden states, which is implemented in the `build_model` function (please refer here for more details). For the second procedure, as Fairseq calls the `get_normalized_probs` function to obtain the probability distribution over tokens for almost all sequence-to-sequence generation models (please refer here for more details), we rewrite the `get_normalized_probs` function to incorporate the retrieval results into the model. Consequently, the detailed model implementation lies in the rewritten `get_normalized_probs` function.
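
The hook mechanism itself is plain PyTorch: `register_forward_hook` calls a function with the module, its inputs, and its output after every forward pass. A minimal sketch of the idea (the class below is a stand-in; see `knn_models/hook_utils.py` for the actual `ForwardHook`):

```python
import torch


class CapturingHook:
    """Stores the output of the hooked module on every forward pass."""

    def __init__(self):
        self.captured = None

    def __call__(self, module, inputs, output):
        self.captured = output[0] if isinstance(output, tuple) else output


hook = CapturingHook()
layer = torch.nn.Linear(8, 8)  # stand-in for the decoder module
layer.register_forward_hook(hook)
layer(torch.randn(2, 8))
print(hook.captured.shape)  # hidden states that would serve as datastore keys
```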
Running `build_faiss_index` is equivalent to running the `build_faiss_index.py` script in the `knn_models_cli` folder, and `count_tokens` works in a similar way. Specifically, we use the `console_scripts` entry point to provide command line tools; all the commands in kNN-models can be found in `setup.py`.
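
For reference, a `console_scripts` entry point in `setup.py` looks roughly like this (an illustrative excerpt; the module paths and function names below are assumptions, so check the actual `setup.py` for the real mapping):

```python
from setuptools import setup

setup(
    name="knn-models",
    entry_points={
        "console_scripts": [
            # hypothetical entries for illustration
            "build_faiss_index = knn_models_cli.build_faiss_index:cli_main",
            "count_tokens = knn_models_cli.count_tokens:cli_main",
        ],
    },
)
```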
How can I make the GPU used by Faiss different from the one used by the model during training or inference?

You can control the GPU used by Faiss by setting the `--knn-device-id` flag (e.g., `--knn-device-id 1`). Additionally, please also add the `--distributed-world-size` flag (e.g., `--distributed-world-size 1`) during training to ensure that the GPU used by Faiss is not occupied by the model.
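
For example, a training invocation might combine the two flags like this (a sketch; the remaining flags are elided and depend on your setup):

```bash
# Model on GPU 0, Faiss index on GPU 1; train on a single GPU so that
# GPU 1 stays free for Faiss.
CUDA_VISIBLE_DEVICES=0,1 fairseq-train ... \
    --distributed-world-size 1 \
    --knn-device-id 1
```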
## Acknowledgements

We are extremely grateful to the research communities for their incredible work on retrieval-augmented sequence modeling. This repository would not have been possible without them. Furthermore, we would also like to thank wls for his generous help and valuable suggestions in replicating PCKMT.