feat(types): support lambda function in parallel mixin by hanxiao · Pull Request #4022 · jina-ai/serve · GitHub

feat(types): support lambda function in parallel mixin #4022


Merged
merged 3 commits on Dec 3, 2021

Changes from all commits
10 changes: 2 additions & 8 deletions docs/fundamentals/document/documentarray-api.md
@@ -860,7 +860,7 @@ map-thread ... map-thread takes 10 seconds (10.28s)
foo-loop ... foo-loop takes 18 seconds (18.52s)
```

One can see big improvement with `.map()`.
One can see a significant speedup with `.map()`.

```{admonition} When to choose process or thread backend?
:class: important
@@ -875,18 +875,12 @@ If you only modify elements in-place, and do not need return values, you can write:

```python
da = DocumentArray(...)
list(da.map(func))
da.apply(func)
```

This follows the same convention as you using Python built-in `map()`.

You can also use `.apply()` which always returns a `DocumentArray`.

````




## Sampling

`DocumentArray` provides a `.sample` function that samples `k` elements without replacement. It accepts two parameters, `k`
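The docs change above boils down to preferring `.apply()` over draining `.map()` when you only mutate Documents in place. Below is a minimal sketch of the two call styles, not taken from the PR; it assumes the top-level `jina` imports, and the `to_upper` helper is purely illustrative.

```python
from jina import Document, DocumentArray


def to_upper(d: Document) -> Document:
    # mutate a single Document in place and hand it back
    d.text = d.text.upper()
    return d


da = DocumentArray(Document(text=t) for t in ('hello', 'world'))

# `.map()` yields whatever `func` returns as a generator, so it must be drained
list(da.map(to_upper))

# `.apply()` runs the same function but always hands back a `DocumentArray`
da_new = da.apply(to_upper)
```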
2 changes: 1 addition & 1 deletion extra-requirements.txt
@@ -78,6 +78,6 @@ pytest-lazy-fixture: test
datasets: cicd
av: cicd
trimesh: cicd
paddlepaddle: cicd
paddlepaddle==2.2.0: cicd
onnx: cicd
onnxruntime: cicd
2 changes: 1 addition & 1 deletion jina/resources/extra-requirements.txt
@@ -78,6 +78,6 @@ pytest-lazy-fixture: test
datasets: cicd
av: cicd
trimesh: cicd
paddlepaddle: cicd
paddlepaddle==2.2.0: cicd
onnx: cicd
onnxruntime: cicd
48 changes: 39 additions & 9 deletions jina/types/arrays/mixins/match.py
@@ -32,6 +32,7 @@ def match(
only_id: bool = False,
use_scipy: bool = False,
device: str = 'cpu',
num_worker: Optional[int] = None,
**kwargs,
) -> None:
"""Compute embedding based nearest neighbour in `another` for each Document in `self`,
@@ -54,9 +55,8 @@ def match(
the min distance will be rescaled to `a`, the max distance will be rescaled to `b`
all values will be rescaled into range `[a, b]`.
:param metric_name: if provided, then match result will be marked with this string.
:param batch_size: if provided, then `darray` is loaded in chunks of, at most, batch_size elements. This option
will be slower but more memory efficient. Specialy indicated if `darray` is a big
DocumentArrayMemmap.
:param batch_size: if provided, then ``darray`` is loaded in batches, where each batch holds at most ``batch_size``
elements. When ``darray`` is big, this can significantly speed up the computation.
:param traversal_ldarray: DEPRECATED. if set, then matching is applied along the `traversal_path` of the
left-hand ``DocumentArray``.
:param traversal_rdarray: DEPRECATED. if set, then matching is applied along the `traversal_path` of the
Expand All @@ -68,6 +68,11 @@ def match(
:param use_scipy: if set, use ``scipy`` as the computation backend. Note, ``scipy`` does not support distance
on sparse matrix.
:param device: the computational device for ``.match()``, can be either `cpu` or `cuda`.
:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.

.. note::
This argument is only effective when ``batch_size`` is set.

:param kwargs: other kwargs.
"""
if limit is not None:
@@ -136,7 +141,7 @@

if batch_size:
dist, idx = lhv._match_online(
rhv, cdist, _limit, normalization, metric_name, batch_size
rhv, cdist, _limit, normalization, metric_name, batch_size, num_worker
)
else:
dist, idx = lhv._match(rhv, cdist, _limit, normalization, metric_name)
@@ -204,7 +209,14 @@ def _match(self, darray, cdist, limit, normalization, metric_name):
return dist, idx

def _match_online(
self, darray, cdist, limit, normalization, metric_name, batch_size
self,
darray,
cdist,
limit,
normalization,
metric_name,
batch_size,
num_worker,
):
"""
Computes the matches between self and `darray`, loading `darray` into main memory in chunks of size `batch_size`.
@@ -218,6 +230,7 @@
all values will be rescaled into range `[a, b]`.
:param batch_size: length of the chunks loaded into memory from darray.
:param metric_name: if provided, then match result will be marked with this string.
:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:return: distances and indices
"""

@@ -227,17 +240,34 @@
idx = 0
top_dists = np.inf * np.ones((n_x, limit))
top_inds = np.zeros((n_x, limit), dtype=int)
for ld in darray.batch(batch_size=batch_size):
y_batch = ld.embeddings

def _get_dist(da: 'DocumentArray'):
y_batch = da.embeddings

distances = cdist(x_mat, y_batch, metric_name)
dists, inds = top_k(distances, limit, descending=False)

if isinstance(normalization, (tuple, list)) and normalization is not None:
dists = minmax_normalize(dists, normalization)

inds = idx + inds
idx += y_batch.shape[0]
return dists, inds, y_batch.shape[0]

if num_worker is None or num_worker > 1:
# note that almost all computations (regardless of the framework) are conducted in C,
# hence there is no worry about the Python GIL and the backend can be safely set to `thread` to
# avoid unnecessary data passing. This in fact gives a huge boost in performance.
_gen = darray.map_batch(
_get_dist,
batch_size=batch_size,
backend='thread',
num_worker=num_worker,
)
else:
_gen = (_get_dist(b) for b in darray.batch(batch_size=batch_size))

for (dists, inds, _bs) in _gen:
inds += idx
idx += _bs
top_dists, top_inds = update_rows_x_mat_best(
top_dists, top_inds, dists, inds, limit
)
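To see the new `num_worker` knob end to end, here is a rough usage sketch of batched matching. It is not part of the diff; the sizes, the explicit `metric='cosine'` argument, and the random embeddings are illustrative assumptions.

```python
import numpy as np
from jina import Document, DocumentArray

queries = DocumentArray(
    Document(embedding=e) for e in np.random.random([16, 128]).astype('float32')
)
index = DocumentArray(
    Document(embedding=e) for e in np.random.random([10000, 128]).astype('float32')
)

# the index is scored in chunks of 1000 embeddings; with this PR those chunks go
# through `map_batch(..., backend='thread')`, so up to 4 batches are scored
# concurrently. `num_worker` has no effect unless `batch_size` is also set.
queries.match(index, metric='cosine', limit=5, batch_size=1000, num_worker=4)

for q in queries:
    print(q.id, [m.id for m in q.matches])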
26 changes: 25 additions & 1 deletion jina/types/arrays/mixins/parallel.py
@@ -1,7 +1,9 @@
import sys
from types import LambdaType
from typing import Callable, TYPE_CHECKING, Generator, Optional, overload, TypeVar

if TYPE_CHECKING:
from ....helper import T
from ....helper import T, random_identity
from ...document import Document
from .... import DocumentArray

@@ -76,6 +78,8 @@ def map(
:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:yield: anything returned from ``func``
"""
if _is_lambda_or_local_function(func) and backend == 'process':
func = _globalize_lambda_function(func)
with _get_pool(backend, num_worker) as p:
for x in p.imap(func, self):
yield x
@@ -154,6 +158,9 @@ def map_batch(
:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:yield: anything returned from ``func``
"""

if _is_lambda_or_local_function(func) and backend == 'process':
func = _globalize_lambda_function(func)
with _get_pool(backend, num_worker) as p:
for x in p.imap(func, self.batch(batch_size=batch_size, shuffle=shuffle)):
yield x
@@ -172,3 +179,20 @@ def _get_pool(backend, num_worker):
raise ValueError(
f'`backend` must be either `process` or `thread`, receiving {backend}'
)


def _is_lambda_or_local_function(func):
return (isinstance(func, LambdaType) and func.__name__ == '<lambda>') or (
'<locals>' in func.__qualname__
)


def _globalize_lambda_function(func):
def result(*args, **kwargs):
return func(*args, **kwargs)

from ....helper import random_identity

result.__name__ = result.__qualname__ = random_identity()
setattr(sys.modules[result.__module__], result.__name__, result)
return result
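The helper pair at the bottom of `parallel.py` exists because the `process` backend pickles `func`, and plain lambdas (or functions defined inside another function) cannot be pickled by name. Here is a standalone sketch of the same trick outside Jina, showing why re-registering a wrapper under a fresh module-level name makes it picklable; the `globalize` helper below is illustrative, not the library code.

```python
import pickle
import sys
import uuid


def globalize(func):
    # wrap the lambda and publish the wrapper under a fresh module-level name,
    # so pickle can serialize it by reference (module + qualname)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    wrapper.__name__ = wrapper.__qualname__ = uuid.uuid4().hex
    setattr(sys.modules[wrapper.__module__], wrapper.__name__, wrapper)
    return wrapper


square = lambda x: x * x

try:
    pickle.dumps(square)  # fails: a lambda has no importable name
except pickle.PicklingError as e:
    print('cannot pickle the raw lambda:', e)

payload = pickle.dumps(globalize(square))  # works: pickled by reference
print(len(payload), 'bytes')
```

Note that the wrapper is pickled only by reference, so the worker process must still be able to resolve the registered name (as with fork-style pools); this is a property of the trick in general, not something guaranteed by the PR itself.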
10 changes: 7 additions & 3 deletions jina/types/document/mixins/sugar.py
@@ -28,6 +28,7 @@ def match(
exclude_self: bool = False,
only_id: bool = False,
use_scipy: bool = False,
num_worker: Optional[int] = None,
) -> 'T':
"""Matching the current Document against a set F438 of Documents.

@@ -45,14 +46,17 @@
the min distance will be rescaled to `a`, the max distance will be rescaled to `b`
all values will be rescaled into range `[a, b]`.
:param metric_name: if provided, then match result will be marked with this string.
:param batch_size: if provided, then `darray` is loaded in chunks of, at most, batch_size elements. This option
will be slower but more memory efficient. Specialy indicated if `darray` is a big
DocumentArrayMemmap.
:param batch_size: if provided, then ``darray`` is loaded in batches, where each batch holds at most ``batch_size``
elements. When ``darray`` is big, this can significantly speed up the computation.
:param exclude_self: if set, Documents in ``darray`` with same ``id`` as the left-hand values will not be
considered as matches.
:param only_id: if set, then returning matches will only contain ``id``
:param use_scipy: if set, use ``scipy`` as the computation backend. Note, ``scipy`` does not support distance
on sparse matrix.
:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.

.. note::
This argument is only effective when ``batch_size`` is set.
"""
...

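The same two knobs surface on the single-`Document` sugar interface. A rough sketch of how they would be passed through, with made-up sizes and random embeddings for illustration:

```python
import numpy as np
from jina import Document, DocumentArray

query = Document(embedding=np.random.random(128).astype('float32'))
index = DocumentArray(
    Document(embedding=np.random.random(128).astype('float32')) for _ in range(5000)
)

# batched matching of a single Document; `num_worker` again only applies
# together with `batch_size`
query.match(index, limit=3, batch_size=500, num_worker=2)
print([m.id for m in query.matches])
```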
15 changes: 13 additions & 2 deletions tests/unit/types/arrays/mixins/test_parallel.py
@@ -1,6 +1,6 @@
import pytest

from jina import DocumentArrayMemmap, DocumentArray, Document
from jina import DocumentArray, Document, DocumentArrayMemmap


def foo(d: Document):
@@ -20,7 +20,7 @@ def foo_batch(da: DocumentArray):

@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArrayMemmap])
@pytest.mark.parametrize('backend', ['process', 'thread'])
@pytest.mark.parametrize('num_worker', [1, 2])
@pytest.mark.parametrize('num_worker', [1, 2, None])
def test_parallel_map(pytestconfig, da_cls, backend, num_worker):
da = da_cls.from_files(f'{pytestconfig.rootdir}/docs/**/*.png')[:10]

@@ -71,3 +71,14 @@ def test_parallel_map_batch(pytestconfig, da_cls, backend, num_worker, b_size):

da_new = da.apply_batch(foo_batch, batch_size=b_size)
assert da_new.blobs.shape == (len(da_new), 3, 222, 222)


@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArrayMemmap])
def test_map_lambda(pytestconfig, da_cls):
da = da_cls.from_files(f'{pytestconfig.rootdir}/docs/**/*.png')[:10]

for d in da:
assert d.blob is None

for d in da.map(lambda x: x.load_uri_to_image_blob()):
assert d.blob is not None