feat: ranking evaluate driver may need more than one field by JoanFM · Pull Request #1953 · jina-ai/serve · GitHub

feat: ranking evaluate driver may need more than one field #1953

Merged
4 commits merged on Feb 18, 2021
58 changes: 37 additions & 21 deletions jina/drivers/evaluate.py
@@ -1,15 +1,16 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from typing import Any, Iterator
from typing import Any, Iterator, Optional, Tuple, Union

import numpy
import numpy as np

from . import BaseExecutableDriver, RecursiveMixin
from ..types.querylang.queryset.dunderkey import dunder_get
from .search import KVSearchDriver
from ..types.document import Document
from ..types.document.helper import DocGroundtruthPair
from ..helper import deprecated_alias


class BaseEvaluateDriver(RecursiveMixin, BaseExecutableDriver):
@@ -27,7 +28,8 @@ class BaseEvaluateDriver(RecursiveMixin, BaseExecutableDriver):
:param *args:
:param **kwargs:
"""
def __init__(self, executor: str = None,

def __init__(self, executor: Optional[str] = None,
method: str = 'evaluate',
running_avg: bool = False,
*args,
@@ -84,7 +86,8 @@ class FieldEvaluateDriver(BaseEvaluateDriver):
Evaluate on the values from certain field, the extraction is implemented with :meth:`dunder_get`
"""

def __init__(self, field: str,
def __init__(self,
field: str,
*args,
**kwargs):
"""
@@ -105,33 +108,46 @@ def extract(self, doc: 'Document') -> Any:
return dunder_get(doc, self.field)


class RankEvaluateDriver(FieldEvaluateDriver):
class RankEvaluateDriver(BaseEvaluateDriver):
"""Drivers used to pass `matches` from documents and groundtruths to an executor and add the evaluation value

- Example fields:
        ['tags__id', 'id', 'score__value']
        ['tags__id', 'score__value']

    :param fields: the field names to be extracted from the Protobuf.
    The differences with :class:`FieldEvaluateDriver` are:
- More than one field is allowed. For instance, for NDCGComputation you may need to have both `ID` and `Relevance` information.
- The fields are extracted from the `matches` of the `Documents` and the `Groundtruth` so it returns a sequence of values.
:param *args:
:param **kwargs:
"""

@deprecated_alias(field=('fields', 0))
def __init__(self,
field: str = 'tags__id',
                 fields: Union[str, Tuple[str]] = ('tags__id',),  # str maintained for backwards compatibility
*args,
**kwargs):
"""
        :param fields: the field names to be extracted from the Protobuf
:param *args: *args for super()
:param **kwargs: **kwargs for super()
"""
super().__init__(field, *args, **kwargs)
super().__init__(*args, **kwargs)
self.fields = fields

    def extract(self, doc: 'Document'):
        """Extract the field from the Document's matches.

        :param doc: the Document
        :return: list of the fields
        """
        r = [dunder_get(x, self.field) for x in doc.matches]
        # flatten nested list but useless depth, e.g. [[1,2,3,4]]
        return list(numpy.array(r).flat)

    @property
    def single_field(self):
        if isinstance(self.fields, str):
            return self.fields
        elif len(self.fields) == 1:
            return self.fields[0]

    def extract(self, doc: 'Document'):
        single_field = self.single_field
        if single_field:
            r = [dunder_get(x, single_field) for x in doc.matches]
            # TODO: Clean this, optimization for `hello-world` because it passes a list of 6k elements in a single
            #  match. See `pseudo_match` in helloworld/helper.py _get_groundtruths
            ret = list(np.array(r).flat)
        else:
            ret = [tuple(dunder_get(x, field) for field in self.fields) for x in doc.matches]

        return ret


class NDArrayEvaluateDriver(FieldEvaluateDriver):
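To see how the reworked `extract` behaves, here is a minimal, self-contained sketch of the same branching. The dict-based matches and the simplified `dunder_get` below are stand-ins for illustration only, not jina's actual types; only the single-field versus multi-field logic mirrors the driver code above.

import numpy as np


def dunder_get(obj: dict, key: str):
    # simplified stand-in for jina's dunder_get: resolve 'a__b' as obj['a']['b']
    for part in key.split('__'):
        obj = obj[part]
    return obj


def extract(matches, fields):
    # mirror RankEvaluateDriver.extract: one field -> flat list, several fields -> list of tuples
    single_field = fields if isinstance(fields, str) else (fields[0] if len(fields) == 1 else None)
    if single_field:
        r = [dunder_get(m, single_field) for m in matches]
        return list(np.array(r).flat)  # flattens the useless extra depth, e.g. [[1, 2, 3, 4]]
    return [tuple(dunder_get(m, f) for f in fields) for m in matches]


matches = [{'tags': {'id': 1}, 'score': {'value': 0.9}},
           {'tags': {'id': 2}, 'score': {'value': 0.7}}]

print(extract(matches, ('tags__id',)))                 # flat list of ids
print(extract(matches, ('tags__id', 'score__value')))  # list of (id, score) tuples

The tuple output in the multi-field case is what evaluators needing more than one signal per match, such as an NDCG-style metric, would consume.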
2 changes: 1 addition & 1 deletion jina/resources/executors._eval_pr.yml
@@ -14,7 +14,7 @@ requests:
with:
traversal_paths: ['r']
running_avg: true
field: tags__id
fields: [tags__id]
drivers:
- !RankEvaluateDriver
with:
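The YAML above swaps the old `field: tags__id` key for the new `fields: [tags__id]` list. On the Python side, the `@deprecated_alias(field=('fields', 0))` decorator on `RankEvaluateDriver.__init__` keeps old keyword arguments working. jina's real helper lives in `jina.helper` and is richer than this, but a hypothetical minimal alias decorator conveying the idea could look like the sketch below (the second tuple element is ignored here and only mirrors the call shape in the diff).

import functools
import warnings


def deprecated_alias(**aliases):
    # hypothetical minimal sketch: map a deprecated kwarg name to its new name
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for old_name, (new_name, _pos) in aliases.items():
                if old_name in kwargs:
                    warnings.warn(f'`{old_name}` is deprecated, use `{new_name}` instead',
                                  DeprecationWarning)
                    kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)
        return wrapper
    return decorator


class Driver:
    @deprecated_alias(field=('fields', 0))
    def __init__(self, fields=('tags__id',), **kwargs):
        self.fields = fields


d = Driver(field='tags__id')   # old-style call still works, with a DeprecationWarning
print(d.fields)                # 'tags__id', a bare str

With this kind of mapping, a legacy config or call site passing `field` ends up as a plain string in `fields`, which the `single_field` branch of the driver still accepts.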
3 changes: 2 additions & 1 deletion jina/resources/executors.requests.BaseRankingEvaluator.yml
@@ -14,4 +14,5 @@ on:
- text
- !RankEvaluateDriver
with:
id_tag: 'id'
fields: ['id']

39 changes: 29 additions & 10 deletions tests/unit/drivers/test_rankingevaluation_driver.py
@@ -1,3 +1,5 @@
from typing import Tuple

import pytest

from jina.drivers.evaluate import RankEvaluateDriver
@@ -8,8 +10,8 @@

class SimpleRankEvaluateDriver(RankEvaluateDriver):

def __init__(self, field: str, *args, **kwargs):
super().__init__(field, *args, **kwargs)
def __init__(self, fields: Tuple[str], *args, **kwargs):
super().__init__(fields, *args, **kwargs)

@property
def exec_fn(self):
@@ -22,8 +24,8 @@ def expect_parts(self) -> int:

class RunningAvgRankEvaluateDriver(RankEvaluateDriver):

    def __init__(self, field: str, *args, **kwargs):
        super().__init__(field, running_avg=True, *args, **kwargs)
    def __init__(self, fields: Tuple[str], *args, **kwargs):
        super().__init__(fields, running_avg=True, *args, **kwargs)

@property
def exec_fn(self):
@@ -35,13 +37,13 @@ def expect_parts(self) -> int:


@pytest.fixture
def simple_rank_evaluate_driver(field):
return SimpleRankEvaluateDriver(field)
def simple_rank_evaluate_driver(fields):
return SimpleRankEvaluateDriver(fields)


@pytest.fixture
def runningavg_rank_evaluate_driver(field):
return RunningAvgRankEvaluateDriver(field)
def runningavg_rank_evaluate_driver(fields):
return RunningAvgRankEvaluateDriver(fields)


@pytest.fixture
@@ -64,7 +66,7 @@ def add_matches(doc: jina_pb2.DocumentProto, num_matches):
return pairs


@pytest.mark.parametrize('field', ['tags__id', 'score__value'])
@pytest.mark.parametrize('fields', [('tags__id',), ('score__value',), 'tags__id', 'score__value'])
def test_ranking_evaluate_simple_driver(simple_rank_evaluate_driver,
ground_truth_pairs):
simple_rank_evaluate_driver.attach(executor=PrecisionEvaluator(eval_at=2), runtime=None)
@@ -76,7 +78,24 @@ def test_ranking_evaluate_simple_driver(simple_rank_evaluate_driver,
assert doc.evaluations[0].value == 1.0


@pytest.mark.parametrize('field', ['tags__id', 'score__value'])
@pytest.mark.parametrize('fields', [('tags__id', 'score__value')])
def test_ranking_evaluate_extract_multiple_fields(simple_rank_evaluate_driver,
                                                  ground_truth_pairs,
                                                  mocker):
    m = mocker.Mock()

def _eval_fn(actual, desired):
m()
assert isinstance(actual[0], Tuple)
assert isinstance(desired[0], Tuple)
return 1.0

simple_rank_evaluate_driver._exec_fn = _eval_fn
simple_rank_evaluate_driver._apply_all(ground_truth_pairs)
m.assert_called()


@pytest.mark.parametrize('fields', [('tags__id',), ('score__value',)])
def test_ranking_evaluate_runningavg_driver(runningavg_rank_evaluate_driver,
ground_truth_pairs):
runningavg_rank_evaluate_driver.attach(executor=PrecisionEvaluator(eval_at=2), runtime=None)
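The new `test_ranking_evaluate_extract_multiple_fields` only asserts that both `actual` and `desired` reach the executor as sequences of tuples; the evaluators used in these tests are jina's own. As a purely illustrative consumer of that tuple shape, assuming hypothetical (id, relevance) pairs, a toy NDCG@k-style function might look like this (not jina's NDCG evaluator):

import math
from typing import Sequence, Tuple


def toy_ndcg(actual: Sequence[Tuple[str, float]],
             desired: Sequence[Tuple[str, float]],
             eval_at: int = 2) -> float:
    # rank order comes from `actual`, relevance comes from the ground truth in `desired`
    relevance = dict(desired)
    gains = [relevance.get(doc_id, 0.0) for doc_id, _ in actual[:eval_at]]
    dcg = sum(g / math.log2(i + 2) for i, g in enumerate(gains))
    ideal = sorted((rel for _, rel in desired), reverse=True)[:eval_at]
    idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


actual = [('a', 0.9), ('b', 0.7), ('c', 0.1)]   # (id, score) tuples as extracted from matches
desired = [('a', 3.0), ('c', 2.0), ('b', 0.0)]  # (id, relevance) tuples from the groundtruth
print(toy_ndcg(actual, desired, eval_at=2))

This is the kind of metric the class docstring above alludes to: precision only needs the ids, while NDCG needs both the id and a relevance value per match, hence the move from a single `field` to a tuple of `fields`.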