Adding duo T5 by wiltan-uw · Pull Request #127 · castorini/pygaggle · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Adding duo T5 #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion pygaggle/model/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from pygaggle.data.segmentation import SegmentProcessor

__all__ = ['RerankerEvaluator', 'metric_names']
__all__ = ['RerankerEvaluator', 'DuoRerankerEvaluator', 'metric_names']
METRIC_MAP = OrderedDict()


Expand Down Expand Up @@ -184,3 +184,58 @@ def evaluate_by_segments(self,
for metric in metrics:
metric.accumulate(doc_scores, example)
return metrics


class DuoRerankerEvaluator:
    """Two-stage reranking evaluator.

    A pointwise ("mono") reranker first scores every candidate document;
    a pairwise ("duo") reranker then rescores only the top ``mono_hits``
    candidates, overwriting their mono scores. Metrics are accumulated on
    the combined score vector.
    """

    def __init__(self,
                 mono_reranker: Reranker,
                 duo_reranker: Reranker,
                 metric_names: List[str],
                 mono_hits: int = 50,
                 use_tqdm: bool = True,
                 writer: Optional[Writer] = None):
        """
        Args:
            mono_reranker: pointwise first-stage reranker.
            duo_reranker: pairwise second-stage reranker.
            metric_names: keys into METRIC_MAP selecting the metrics to compute.
            mono_hits: number of top mono candidates passed to the duo stage.
            use_tqdm: show progress bars when True.
            writer: optional sink that records per-example score lists.
        """
        self.mono_reranker = mono_reranker
        self.duo_reranker = duo_reranker
        self.mono_hits = mono_hits
        self.metrics = [METRIC_MAP[name] for name in metric_names]
        self.use_tqdm = use_tqdm
        self.writer = writer

    def evaluate(self,
                 examples: List[RelevanceExample]) -> List[MetricAccumulator]:
        """Run both reranking stages over *examples* and return accumulated metrics."""
        metrics = [cls() for cls in self.metrics]
        mono_texts = []
        scores = []
        # Stage 1: mono reranker scores every candidate document.
        for example in tqdm(examples, total=len(examples), disable=not self.use_tqdm):
            mono_out = self.mono_reranker.rerank(example.query, example.documents)
            # Keep (original index, text) of the top-k mono hits for the duo stage;
            # the index lets us write duo scores back into the right positions.
            mono_texts.append(sorted(enumerate(mono_out),
                                     key=lambda x: x[1].score,
                                     reverse=True)[:self.mono_hits])
            scores.append(np.array([x.score for x in mono_out]))
        # Stage 2: duo reranker rescores only the mono top-k; duo scores
        # overwrite the mono scores at those documents' original positions.
        for ct, texts in tqdm(enumerate(mono_texts), total=len(mono_texts),
                              disable=not self.use_tqdm):
            duo_in = [text for _, text in texts]
            duo_scores = [x.score
                          for x in self.duo_reranker.rerank(examples[ct].query, duo_in)]
            scores[ct][[idx for idx, _ in texts]] = duo_scores
            if self.writer is not None:
                self.writer.write(list(scores[ct]), examples[ct])
            for metric in metrics:
                metric.accumulate(list(scores[ct]), examples[ct])
        return metrics

    def evaluate_by_segments(self,
                             examples: List[RelevanceExample],
                             seg_size: int,
                             stride: int,
                             aggregate_method: str) -> List[MetricAccumulator]:
        """Segment each document, score segments, aggregate back to documents.

        NOTE(review): the duo stage is not applied here — segment-level scoring
        uses only the mono reranker, mirroring RerankerEvaluator's behavior.
        """
        metrics = [cls() for cls in self.metrics]
        segment_processor = SegmentProcessor()
        for example in tqdm(examples, disable=not self.use_tqdm):
            segment_group = segment_processor.segment(example.documents, seg_size, stride)
            # BUG FIX: this previously referenced `self.reranker`, which
            # DuoRerankerEvaluator never defines (copied from RerankerEvaluator)
            # and raised AttributeError at runtime.
            segment_group.segments = self.mono_reranker.rerank(example.query,
                                                               segment_group.segments)
            doc_scores = [x.score for x in segment_processor.aggregate(example.documents,
                                                                       segment_group,
                                                                       aggregate_method)]
            if self.writer is not None:
                self.writer.write(doc_scores, example)
            for metric in metrics:
                metric.accumulate(doc_scores, example)
        return metrics
37 changes: 36 additions & 1 deletion pygaggle/model/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Mapping, Union, Iterable, Optional
from typing import List, Mapping, Union, Iterable, Optional, Tuple

from spacy.lang.en import English
from transformers import PreTrainedTokenizer
Expand All @@ -11,7 +11,9 @@

__all__ = ['BatchTokenizer',
'T5BatchTokenizer',
'T5DuoBatchTokenizer',
'QueryDocumentBatch',
'DuoQueryDocumentBatch',
'SimpleBatchTokenizer',
'QueryDocumentBatchTokenizer',
'SpacySenticizer',
Expand Down Expand Up @@ -40,6 +42,16 @@ def __len__(self):
return len(self.documents)


@dataclass
class DuoQueryDocumentBatch:
    """A batch pairing one query with candidate document pairs for duo
    (pairwise) reranking; the tokenizer fills in `output` lazily."""
    query: Query
    # Ordered (document0, document1) pairs to be scored against each other.
    doc_pairs: List[Tuple[Text, Text]]
    # Tokenizer output for the batch; None until the batch has been encoded.
    output: Optional[TokenizerReturnType] = None

    def __len__(self):
        # Batch size is the number of document pairs, not distinct documents.
        return len(self.doc_pairs)

class TokenizerEncodeMixin:
tokenizer: PreTrainedTokenizer = None
tokenizer_kwargs = None
Expand Down Expand Up @@ -105,6 +117,18 @@ def traverse_query_document(
document=doc.text) for doc in docs])
yield QueryDocumentBatch(query, docs, outputs)

def traverse_duo_query_document(
        self,
        batch_input: DuoQueryDocumentBatch) -> Iterable[DuoQueryDocumentBatch]:
    """Yield tokenized sub-batches of at most `self.batch_size` document pairs.

    Each pair is rendered through `self.pattern` with the query text and the
    two documents' texts, then encoded in one tokenizer call per sub-batch.
    """
    query = batch_input.query
    pairs = batch_input.doc_pairs
    for start in range(0, len(pairs), self.batch_size):
        chunk = pairs[start:start + self.batch_size]
        encoded = self.encode([
            self.pattern.format(query=query.text,
                                document0=first.text,
                                document1=second.text)
            for first, second in chunk
        ])
        yield DuoQueryDocumentBatch(query, chunk, encoded)


class T5BatchTokenizer(AppendEosTokenizerMixin, QueryDocumentBatchTokenizer):
def __init__(self, *args, **kwargs):
Expand All @@ -117,6 +141,17 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class T5DuoBatchTokenizer(AppendEosTokenizerMixin, QueryDocumentBatchTokenizer):
    """Batch tokenizer preconfigured for duoT5 pairwise inputs."""

    def __init__(self, *args, **kwargs):
        # Force the duoT5 prompt template and tokenizer settings, overriding
        # anything the caller passed for these keys.
        kwargs.update(
            pattern='Query: {query} Document0: {document0} Document1: {document1} Relevant:',
            return_attention_mask=True,
            padding='longest',
            truncation=True,
            return_tensors='pt',
            max_length=512,
        )
        super().__init__(*args, **kwargs)


class SimpleBatchTokenizer(BatchTokenizer):
def __init__(self, *args, **kwargs):
kwargs['return_attention_mask'] = True
Expand Down
55 changes: 55 additions & 0 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections import defaultdict
from copy import deepcopy
from itertools import permutations
from typing import List

from transformers import (AutoTokenizer,
Expand All @@ -12,13 +14,16 @@
from pygaggle.model import (BatchTokenizer,
LongBatchEncoder,
QueryDocumentBatch,
DuoQueryDocumentBatch,
QueryDocumentBatchTokenizer,
SpecialTokensCleaner,
T5BatchTokenizer,
T5DuoBatchTokenizer,
greedy_decode)


__all__ = ['MonoT5',
'DuoT5',
'UnsupervisedTransformerReranker',
'MonoBERT',
'QuestionAnsweringTransformerReranker']
Expand Down Expand Up @@ -68,6 +73,56 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
return texts


class DuoT5(Reranker):
    """Pairwise (duo) T5 reranker: scores every ordered pair of candidate
    texts and aggregates pair probabilities into per-document scores."""

    def __init__(self,
                 model: T5ForConditionalGeneration = None,
                 tokenizer: QueryDocumentBatchTokenizer = None):
        self.model = model or self.get_model()
        self.tokenizer = tokenizer or self.get_tokenizer()
        self.device = next(self.model.parameters(), None).device

    @staticmethod
    def get_model(pretrained_model_name_or_path: str = 'castorini/duot5-base-msmarco',
                  *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
        """Load a duoT5 model in eval mode on `device` (default: CUDA if available)."""
        target = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
        loaded = T5ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs)
        return loaded.to(target).eval()

    @staticmethod
    def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
                      *args, batch_size: int = 8, **kwargs) -> T5DuoBatchTokenizer:
        """Build the duoT5 batch tokenizer around a slow T5 tokenizer."""
        base = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, use_fast=False, *args, **kwargs)
        return T5DuoBatchTokenizer(base, batch_size=batch_size)

    def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
        """Score `texts` against `query`; returns copies with `.score` set."""
        texts = deepcopy(texts)
        scores = defaultdict(float)
        pairs = list(permutations(texts, 2))
        batch_input = DuoQueryDocumentBatch(query=query, doc_pairs=pairs)
        for batch in self.tokenizer.traverse_duo_query_document(batch_input):
            token_ids = batch.output['input_ids'].to(self.device)
            mask = batch.output['attention_mask'].to(self.device)
            _, last_logits = greedy_decode(self.model,
                                           token_ids,
                                           length=1,
                                           attention_mask=mask,
                                           return_last_logits=True)

            # 6136 and 1176 are the indexes of the tokens false and true in T5.
            pair_logits = last_logits[:, [6136, 1176]]
            pair_probs = torch.nn.functional.softmax(pair_logits, dim=1)[:, 1].tolist()
            # P(doc0 more relevant than doc1) credits doc0; the complement
            # credits doc1. Summing over all pairs yields the final score.
            for (first, second), prob in zip(batch.doc_pairs, pair_probs):
                scores[first.metadata['docid']] += prob
                scores[second.metadata['docid']] += (1 - prob)

        for text in texts:
            text.score = scores[text.metadata['docid']]
        return texts


class UnsupervisedTransformerReranker(Reranker):
methods = dict(max=lambda x: x.max().item(),
mean=lambda x: x.mean().item(),
Expand Down
34 changes: 31 additions & 3 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Tuple
from pathlib import Path
import logging

Expand All @@ -16,13 +16,15 @@
from pygaggle.rerank.transformer import (
UnsupervisedTransformerReranker,
MonoT5,
DuoT5,
MonoBERT
)
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
from pygaggle.model import (SimpleBatchTokenizer,
T5BatchTokenizer,
RerankerEvaluator,
DuoRerankerEvaluator,
metric_names,
MsMarcoWriter)
from pygaggle.data import MsMarcoDataset
Expand All @@ -31,7 +33,7 @@

SETTINGS = MsMarcoSettings()
METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer',
'random')
'random', 'duo_t5')


class PassageRankingEvaluationOptions(BaseModel):
Expand All @@ -40,6 +42,8 @@ class PassageRankingEvaluationOptions(BaseModel):
index_dir: Path
method: str
model: str
duo_model: str
mono_hits: int
split: str
batch_size: int
device: str
Expand Down Expand Up @@ -85,6 +89,15 @@ def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
return MonoT5(model, tokenizer)


def construct_duo_t5(options: PassageRankingEvaluationOptions) -> Tuple[Reranker, Reranker]:
    """Build the (mono, duo) reranker pair for two-stage duoT5 evaluation."""
    mono_reranker = construct_t5(options)
    duo_model = DuoT5.get_model(options.duo_model,
                                from_tf=options.from_tf,
                                device=options.device)
    duo_tokenizer = DuoT5.get_tokenizer(options.model_type,
                                        batch_size=options.batch_size)
    return mono_reranker, DuoT5(duo_model, duo_tokenizer)


def construct_transformer(options:
PassageRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
Expand Down Expand Up @@ -137,6 +150,13 @@ def main():
required=True,
type=str,
help='Path to pre-trained model or huggingface model name'),
opt('--duo_model',
type=str,
help='Path to pre-trained model or huggingface model name'),
opt('--mono_hits',
type=int,
default=50,
help='Top k candidates from mono for duo reranking'),
opt('--output-file', type=Path, default='.'),
opt('--overwrite-output', action='store_true'),
opt('--split',
Expand Down Expand Up @@ -165,11 +185,19 @@ def main():
construct_map = dict(transformer=construct_transformer,
bm25=construct_bm25,
t5=construct_t5,
duo_t5=construct_duo_t5,
seq_class_transformer=construct_seq_class_transformer,
random=lambda _: RandomReranker())
reranker = construct_map[options.method](options)
writer = MsMarcoWriter(args.output_file, args.overwrite_output)
evaluator = RerankerEvaluator(reranker, options.metrics, writer=writer)
if options.method == 'duo_t5':
evaluator = DuoRerankerEvaluator(mono_reranker=reranker[0],
duo_reranker=reranker[1],
metric_names=options.metrics,
mono_hits=options.mono_hits,
writer=writer)
else:
evaluator = RerankerEvaluator(reranker, options.metrics, writer=writer)
width = max(map(len, args.metrics)) + 1
logging.info("Reranking:")
for metric in evaluator.evaluate(examples):
Expand Down
0