duo documentation, cleanup mono documentation by ronakice · Pull Request #146 · castorini/pygaggle · GitHub
duo documentation, cleanup mono documentation #146

Merged · 10 commits · Jan 15, 2021
242 changes: 242 additions & 0 deletions docs/experiments-duot5-tpu.md
@@ -0,0 +1,242 @@
# Neural Pairwise Ranking Baselines on [MS MARCO Passage Retrieval](https://github.com/microsoft/MSMARCO-Passage-Ranking) - with TPU

This page contains instructions for running duoT5 on the MS MARCO *passage* ranking task.

We will focus on using duoT5-3B to rerank, since it is difficult to run such a large model without a TPU.
We also mention the changes required to run duoT5-base for those with a more constrained compute budget.
- duoT5: The Expando-Mono-Duo Design Pattern for Text Ranking with Pretrained Sequence-to-Sequence Models [(Pradeep et al., 2021)](https://arxiv.org/pdf/2101.05667.pdf)

Note that there are also separate documents for running MS MARCO ranking tasks on regular GPUs. Please see [MS MARCO *document* ranking task](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-document.md), [MS MARCO *passage* ranking task - Subset](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-subset.md) and [MS MARCO *passage* ranking task - Entire](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-entire.md).

Prior to running this, we suggest looking at our second-stage [pointwise ranking instructions](https://github.com/castorini/pygaggle/blob/master/docs/experiments-monot5-tpu.md).
Using duoT5, we rerank the monoT5 run files, which contain ~1000 passages per query; we focus on the top 50 passages for each query.
duoT5 is a pairwise reranker.
This means that the reranker estimates the probability that one document is more relevant to the query than another.
These pairwise scores are aggregated to get a single score for each document.
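
To make the aggregation concrete, below is a minimal sketch of the symmetric-sum (`sym_sum`) strategy used later in this guide: the probability that document A is more relevant than document B is added to A's score, and its complement to B's score. The values are made up for illustration; the authoritative logic lives in `pygaggle/data/convert_duot5_output_to_msmarco_run.py`.

```
from collections import defaultdict

# Toy pairwise probabilities p(doc_a is more relevant than doc_b) for one query.
# These numbers are illustrative only.
pairwise = {
    ("d1", "d2"): 0.9,
    ("d2", "d1"): 0.2,
    ("d1", "d3"): 0.7,
    ("d3", "d1"): 0.4,
    ("d2", "d3"): 0.6,
    ("d3", "d2"): 0.3,
}

# Symmetric sum: add each probability to the first document of the pair
# and its complement to the second.
scores = defaultdict(float)
for (doc_a, doc_b), p in pairwise.items():
    scores[doc_a] += p
    scores[doc_b] += 1.0 - p

# Rank documents by aggregated score, highest first.
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)  # ['d1', 'd2', 'd3']
```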

## Data Prep

Since we will use some scripts from PyGaggle to process data and evaluate results, we first install it from source.
```
git clone --recursive https://github.com/castorini/pygaggle.git
cd pygaggle
pip install .
```

We store all the files in the `data/msmarco_passage` directory.
```
export DATA_DIR=data/msmarco_passage
mkdir ${DATA_DIR}
```

We provide specific data prep instructions for evaluating on the dev set.

### Dev Set

We download the query, qrels, and corpus files corresponding to the MS MARCO passage dev set.

The run file is generated by following PyGaggle's [monoT5 TPU instructions](https://github.com/castorini/pygaggle/blob/master/docs/experiments-monot5-tpu.md).

In short, the files are:
- `queries.dev.small.tsv`: 6,980 queries from the MS MARCO dev set.
- `qrels.dev.small.tsv`: 7,437 pairs of query and relevant passage ids from the MS MARCO dev set.
- `collection.tar.gz`: All passages (8,841,823) in the MS MARCO passage corpus. In this tsv file, the first column is the passage id, and the second is the passage text.

A more detailed description of the data is available [here](https://github.com/castorini/duobert#data-and-trained-models).

Let's start.
```
cd ${DATA_DIR}
wget https://www.dropbox.com/s/hq6xjhswiz60siu/queries.dev.small.tsv
wget https://www.dropbox.com/s/5t6e2225rt6ikym/qrels.dev.small.tsv
wget https://www.dropbox.com/s/m1n2wf80l1lb9j1/collection.tar.gz
tar -xvf collection.tar.gz
rm collection.tar.gz
cd ../../
```
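
Optionally, you can peek at the first line of each downloaded file to confirm the formats described above. This is a quick throwaway check (not part of the official pipeline) and assumes the `data/msmarco_passage` layout used in this guide.

```
# Print the tab-separated fields of the first line of each downloaded file.
data_dir = "data/msmarco_passage"
for name in ["queries.dev.small.tsv", "qrels.dev.small.tsv", "collection.tsv"]:
    with open(f"{data_dir}/{name}") as f:
        print(name, "->", next(f).rstrip("\n").split("\t"))
```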

As a sanity check, we can evaluate the second-stage retrieved documents using the official MS MARCO evaluation script.
We use the monoT5-base run file to rerank with duoT5-base and the monoT5-3B run file to rerank with duoT5-3B.
```
export MODEL_NAME=<base or 3B>
python tools/eval/msmarco_eval.py ${DATA_DIR}/qrels.dev.small.tsv ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv
```

In the case of monoT5-3B, the output should be:

```
#####################
MRR @10: 0.3983799517896949
QueriesRanked: 6980
#####################
```

In the case of monoT5-base, the output should be:

```
#####################
MRR @10: 0.38160657433938283
QueriesRanked: 6980
#####################
```

Then, we prepare the query-doc0-doc1 triples in the duoT5 input format.
```
python pygaggle/data/create_msmarco_duot5_input.py --queries ${DATA_DIR}/queries.dev.small.tsv \
--run ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv \
--corpus ${DATA_DIR}/collection.tsv \
--t5_input ${DATA_DIR}/query_docs_triples.dev.small.txt \
--t5_input_ids ${DATA_DIR}/query_docs_triple_ids.dev.small.tsv \
--top_k 50
```
We will get two output files here (a small illustration follows this list):
- `query_docs_triples.dev.small.txt`: The query-doc0-doc1 triples for duoT5 input.
- `query_docs_triple_ids.dev.small.tsv`: The `query_id`s, `doc_id_0`s, and `doc_id_1`s (along with the two documents' monoT5 ranks) that map to the query-doc0-doc1 triples. We will use this file to map query-doc0-doc1 triples to their corresponding duoT5 output scores.
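
To make the two formats concrete, here is a sketch of how one triple might be serialized. The `Query: ... Document0: ... Document1: ... Relevant:` template follows the duoT5 paper and is an assumption here; `create_msmarco_duot5_input.py` is the authoritative source for the exact format.

```
# Illustrative only: toy ids/text and an assumed duoT5 input template.
query_id, query = "q1", "how do pairwise rerankers work"
doc_id_0, doc_0 = "d0", "duoT5 compares two candidate passages at a time ..."
doc_id_1, doc_1 = "d1", "monoT5 scores each passage independently ..."

# One line of query_docs_triples.dev.small.txt (template assumed from the duoT5 paper).
print(f"Query: {query} Document0: {doc_0} Document1: {doc_1} Relevant:")

# One line of query_docs_triple_ids.dev.small.tsv: the ids plus the two documents'
# monoT5 ranks (placeholder values here), which the conversion script later reads.
print("\t".join([query_id, doc_id_0, doc_id_1, "0", "1"]))
```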

The files are made available in our [bucket](https://console.cloud.google.com/storage/browser/castorini/duot5/data).

Note that the duoT5 input file may be too large to fit in the instance's memory, so we split it into multiple files.

```
split --suffix-length 3 --numeric-suffixes --lines 500000 ${DATA_DIR}/query_docs_triples.dev.small.txt ${DATA_DIR}/query_docs_triples.dev.small.txt
```

Splitting `query_docs_triples.dev.small.txt` yields 35 files (`query_docs_triples.dev.small.txt000` to `query_docs_triples.dev.small.txt034`): roughly 6,980 queries × 50 × 49 ordered passage pairs ≈ 17.1M lines, i.e. 35 chunks of at most 500,000 lines.
Note that reranking might still run out of memory, in which case reduce the number of lines per file to below `500000`.

We copy these input files to Google Storage. TPU inference will read data directly from `gs`.
```
export GS_FOLDER=<google storage folder to store input/output data>
gsutil cp ${DATA_DIR}/query_docs_triples.dev.small.txt??? ${GS_FOLDER}
```

## Start a VM with TPU on Google Cloud

Define environment variables.
```
export PROJECT_NAME=<gcloud project name>
export PROJECT_ID=<gcloud project id>
export INSTANCE_NAME=<name of vm to create>
export TPU_NAME=<name of tpu to create>
```

Create the VM.
```
gcloud beta compute --project=${PROJECT_NAME} instances create ${INSTANCE_NAME} --zone=europe-west4-a --machine-type=n1-standard-4 --subnet=default --network-tier=PREMIUM --maintenance-policy=MIGRATE --service-account=${PROJECT_ID}-compute@developer.gserviceaccount.com --scopes=https://www.googleapis.com/auth/cloud-platform --image=debian-10-buster-v20201112 --image-project=debian-cloud --boot-disk-size=25GB --boot-disk-type=pd-standard --boot-disk-device-name=${INSTANCE_NAME} --reservation-affinity=any
```

The `image` and `machine-type` provided here may be dated, so feel free to update them to whatever fits your needs.
After the VM is created, we can `ssh` into the machine.
Make sure to initialize `PROJECT_NAME` and `TPU_NAME` from within the machine too.
Then create a TPU.

```
curl -O https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu && chmod a+x ctpu
./ctpu up --name=${TPU_NAME} --project=${PROJECT_NAME} --zone=europe-west4-a --tpu-size=v3-8 --tpu-only --noconf
```

## Setup environment on VM

Install required tools including [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
```
sudo apt-get update
sudo apt-get install git gcc screen --yes
curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash ./Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc
```
Then create a Python virtual environment for the experiments and install dependencies.
```
conda init
conda create --yes --name py36 python=3.6
conda activate py36
conda install -c conda-forge httptools jsonnet --yes
pip install tensorflow tensorflow-text t5[gcp]
git clone https://github.com/castorini/mesh.git
pip install --editable mesh
```

## Rerank with duoT5

Let's first define the model type and checkpoint.

```
export MODEL_NAME=<base or 3B>
export MODEL_DIR=gs://castorini/duot5/experiments/${MODEL_NAME}
```

Then run the following command to start the process in the background and monitor the log:
```
for ITER in {000..034}; do
echo "Running iter: $ITER" >> out.log_eval_exp
nohup t5_mesh_transformer \
--tpu="${TPU_NAME}" \
--gcp_project="${PROJECT_NAME}" \
--tpu_zone="europe-west4-a" \
--model_dir="${MODEL_DIR}" \
--gin_file="gs://t5-data/pretrained_models/${MODEL_NAME}/operative_config.gin" \
--gin_file="infer.gin" \
--gin_file="beam_search.gin" \
--gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
--gin_param="infer_checkpoint_step = 1150000" \
--gin_param="utils.run.sequence_length = {'inputs': 512, 'targets': 2}" \
--gin_param="Bitransformer.decode.max_decode_length = 2" \
--gin_param="input_filename = '${GS_FOLDER}/query_docs_triples.dev.small.txt${ITER}'" \
--gin_param="output_filename = '${GS_FOLDER}/query_docs_triple_scores.dev.small.txt${ITER}'" \
--gin_param="tokens_per_batch = 65536" \
--gin_param="Bitransformer.decode.beam_size = 1" \
--gin_param="Bitransformer.decode.temperature = 0.0" \
--gin_param="Unitransformer.sample_autoregressive.sampling_keep_top_k = -1" \
>> out.log_eval_exp 2>&1
done &

tail -100f out.log_eval_exp
```

Using a TPU v3-8, it takes approximately 12 hours and 38 hours to rerank with duoT5-base and duoT5-3B respectively.

Note that we strongly encourage you to run any of the long processes in `screen` to make sure they don't get interrupted.

## Evaluate reranked results
After reranking is done, we copy the results from Google Storage to our working directory and concatenate all the score files back into one file.
```
gsutil cp ${GS_FOLDER}/query_docs_triple_scores.dev.small.txt???-1150000 ${DATA_DIR}/
cat ${DATA_DIR}/query_docs_triple_scores.dev.small.txt???-1150000 > ${DATA_DIR}/query_docs_triple_scores.dev.small.txt
```

Then we convert the duoT5 output to the required MS MARCO run format, aggregating the pairwise scores with the symmetric sum (`sym_sum`) strategy.
```
python pygaggle/data/convert_duot5_output_to_msmarco_run.py --t5_output ${DATA_DIR}/query_docs_triple_scores.dev.small.txt \
--t5_output_ids ${DATA_DIR}/query_docs_triple_ids.dev.small.tsv \
--duo_run ${DATA_DIR}/run.duot5_${MODEL_NAME}.dev.tsv \
--input_run ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv \
--aggregate sym_sum
```

Now we can evaluate the reranked results using the official MS MARCO evaluation script.
```
python tools/eval/msmarco_eval.py ${DATA_DIR}/qrels.dev.small.tsv ${DATA_DIR}/run.duot5_${MODEL_NAME}.dev.tsv
```

In the case of duoT5-3B, the output should be:

```
#####################
MRR @10: 0.40913556874516793
QueriesRanked: 6980
#####################
```

In the case of duoT5-base, the output should be:

```
#####################
MRR @10: 0.3929155864829223
QueriesRanked: 6980
#####################
```

If you were able to replicate any of these results, please submit a PR adding to the replication log, along with the model(s) you replicated.
Please mention in your PR if you note any differences.

## Replication Log
8 changes: 4 additions & 4 deletions docs/experiments-monot5-tpu.md
@@ -1,4 +1,4 @@
- # Neural Ranking Baselines on [MS MARCO Passage Retrieval](https://github.com/microsoft/MSMARCO-Passage-Ranking) - with TPU
+ # Neural Pointwise Ranking Baselines on [MS MARCO Passage Retrieval](https://github.com/microsoft/MSMARCO-Passage-Ranking) - with TPU

This page contains instructions for running monoT5 on the MS MARCO *passage* ranking task.

@@ -183,7 +183,7 @@ for ITER in {000..008}; do
echo "Running iter: $ITER" >> out.log_eval_exp
nohup t5_mesh_transformer \
--tpu="${TPU_NAME}" \
- --gcp_project=${PROJECT_NAME} \
+ --gcp_project="${PROJECT_NAME}" \
--tpu_zone="europe-west4-a" \
--model_dir="${MODEL_DIR}" \
--gin_file="gs://t5-data/pretrained_models/${MODEL_NAME}/operative_config.gin" \
@@ -218,9 +218,9 @@ cat ${DATA_DIR}/query_doc_pair_scores.dev.small.txt???-1100000 > ${DATA_DIR}/que

Then we convert the monoT5 output to the required MSMARCO format.
```
- python pygaggle/data/convert_t5_output_to_msmarco_run.py --t5_output ${DATA_DIR}/query_doc_pair_scores.dev.small.txt \
+ python pygaggle/data/convert_monot5_output_to_msmarco_run.py --t5_output ${DATA_DIR}/query_doc_pair_scores.dev.small.txt \
--t5_output_ids ${DATA_DIR}/query_doc_pair_ids.dev.small.tsv \
- --msmarco_run ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv
+ --mono_run ${DATA_DIR}/run.monot5_${MODEL_NAME}.dev.tsv
```

Now we can evaluate the reranked results using the official MS MARCO evaluation script.
86 changes: 86 additions & 0 deletions pygaggle/data/convert_duot5_output_to_msmarco_run.py
@@ -0,0 +1,86 @@
"""
This script convert duoT5 output file to msmarco run file
"""
import argparse
import collections
import numpy as np
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--t5_output", type=str, required=True,
                    help="tsv file with two columns, <label> and <score>")
parser.add_argument("--t5_output_ids", type=str, required=True,
                    help="tsv file with five columns <query_id>, <doc_id_a>, <doc_id_b>, <rank_a> and <rank_b>")
parser.add_argument("--input_run", type=str, required=True,
                    help="path to input run, tsv file, with <query_id>, <doc_id> and <rank>")
parser.add_argument("--duo_run", type=str, required=True,
                    help="path to output duo run, tsv file, with <query_id>, <doc_id> and <rank>")
parser.add_argument("--top_k", type=int, default=50,
                    help="top-k pointwise hits to be reranked by pairwise ranker")
parser.add_argument("--aggregate", type=str, default="sym_sum",
                    help="aggregation technique: one of sum, sym_sum, log_sum or sym_log_sum")

args = parser.parse_args()


def load_run(path):
    """Loads run into a dict of key: query_id, value: list of candidate doc
    ids."""

    # We want to preserve the order of runs so we can pair the run file with
    # the TFRecord file.
    print('Loading run...')
    run = collections.OrderedDict()
    with open(path) as f:
        for line in tqdm(f):
            query_id, doc_title, rank = line.split('\t')
            if query_id not in run:
                run[query_id] = []
            run[query_id].append((doc_title, int(rank)))

    # Sort candidate docs by rank.
    print('Sorting candidate docs by rank...')
    sorted_run = collections.OrderedDict()
    for query_id, doc_titles_ranks in tqdm(run.items()):
        doc_titles_ranks.sort(key=lambda x: x[1])
        doc_titles = [doc_title for doc_title, _ in doc_titles_ranks]
        sorted_run[query_id] = doc_titles

    return sorted_run


input_run = load_run(path=args.input_run)

# Aggregate pairwise scores into a single score per (query, document).
examples = collections.defaultdict(dict)
with open(args.t5_output_ids) as f_gt, open(args.t5_output) as f_pred:
    for line_gt, line_pred in zip(f_gt, f_pred):
        query_id, doc_id_a, doc_id_b, ct_a, ct_b = line_gt.strip().split('\t')
        _, score = line_pred.strip().split('\t')
        score = float(score)
        if int(ct_a) < args.top_k and int(ct_b) < args.top_k:
            if doc_id_a not in examples[query_id]:
                examples[query_id][doc_id_a] = 0
            if "log" not in args.aggregate:
                # Non-log variants exponentiate the raw (log-probability) score.
                score = np.exp(score)
            examples[query_id][doc_id_a] += score
            if "sym" in args.aggregate:
                # Symmetric variants also credit the complementary score to doc_b.
                if doc_id_b not in examples[query_id]:
                    examples[query_id][doc_id_b] = 0
                if "log" in args.aggregate:
                    score_b = np.log(1 - np.exp(score))
                else:
                    score_b = 1 - score
                examples[query_id][doc_id_b] += score_b

for qid in examples:
    examples[qid] = list(examples[qid].items())

# Write the duo run: reranked top-k docs first, then the remaining docs from the input run.
with open(args.duo_run, 'w') as fout:
    for query_id, doc_ids_scores in examples.items():
        doc_ids_scores.sort(key=lambda x: x[1], reverse=True)
        for rank, (doc_id, _) in enumerate(doc_ids_scores):
            fout.write(f'{query_id}\t{doc_id}\t{rank + 1}\n')
        input_offset = len(doc_ids_scores)
        for rank, doc_id in enumerate(input_run[query_id]):
            if rank < input_offset:
                continue
            fout.write(f'{query_id}\t{doc_id}\t{rank + 1}\n')
pygaggle/data/convert_monot5_output_to_msmarco_run.py
@@ -9,8 +9,8 @@
help="tsv file with two columns, <label> and <score>")
parser.add_argument("--t5_output_ids", type=str, required=True,
help="tsv file with two columns, <query_id> and <doc_id>")
parser.add_argument("--msmarco_run", type=str, required=True,
help="path to msmarco_run, tsv file, with <query_id>, <doc_id> and <rank>")
parser.add_argument("--mono_run", type=str, required=True,
help="path to output mono run, tsv file, with <query_id>, <doc_id> and <rank>")
args = parser.parse_args()

examples = collections.defaultdict(list)
@@ -21,7 +21,7 @@
        score = float(score)
        examples[query_id].append((doc_id, score))

- with open(args.msmarco_run, 'w') as fout:
+ with open(args.mono_run, 'w') as fout:
    for query_id, doc_ids_scores in examples.items():
        doc_ids_scores.sort(key=lambda x: x[1], reverse=True)
        for rank, (doc_id, _) in enumerate(doc_ids_scores):