fix(core): fix request/response missmatch while prefetching by jacobowitz · Pull Request #2843 · jina-ai/serve · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

fix(core): fix request/response missmatch while prefetching #2843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions jina/peapods/runtimes/gateway/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
import os
from typing import AsyncGenerator

import grpc

from ..prefetch import PrefetchMixin
from ..prefetch import PrefetchCaller
from ...zmq.asyncio import AsyncNewLoopRuntime
from ....zmq import AsyncZmqlet
from .....logging.logger import JinaLogger
from .....proto import jina_pb2_grpc

__all__ = ['GRPCRuntime']

from .....types.message import Message

class GRPCPrefetchCall(PrefetchMixin, jina_pb2_grpc.JinaRPCServicer):

class GRPCPrefetchCall(jina_pb2_grpc.JinaRPCServicer):
"""JinaRPCServicer """

def __init__(self, args, zmqlet):
super().__init__()
self.args = args
self.zmqlet = zmqlet
self.name = args.name or self.__class__.__name__
self.logger = JinaLogger(self.name, **vars(args))
self._servicer = PrefetchCaller(args, zmqlet)
self.Call = self._servicer.send
self.close = self._servicer.close


class GRPCRuntime(AsyncNewLoopRuntime):
Expand All @@ -42,16 +43,16 @@ async def async_setup(self):
]
)
self.zmqlet = AsyncZmqlet(self.args, logger=self.logger)
jina_pb2_grpc.add_JinaRPCServicer_to_server(
GRPCPrefetchCall(self.args, self.zmqlet), self.server
)
self._prefetcher = GRPCPrefetchCall(self.args, self.zmqlet)
jina_pb2_grpc.add_JinaRPCServicer_to_server(self._prefetcher, self.server)
bind_addr = f'{self.args.host}:{self.args.port_expose}'
self.server.add_insecure_port(bind_addr)
await self.server.start()

async def async_cancel(self):
"""The async method to stop server."""
await self.server.stop(0)
await self._prefetcher.close()

async def async_run_forever(self):
"""The async running of server."""
Expand Down
5 changes: 3 additions & 2 deletions jina/peapods/runtimes/gateway/http/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'):
servicer = PrefetchCaller(args, zmqlet)

@app.on_event('shutdown')
def _shutdown():
async def _shutdown():
await servicer.close()
zmqlet.close()

openapi_tags = []
Expand Down Expand Up @@ -197,7 +198,7 @@ async def _get_singleton_result(req_iter) -> Dict:
:param req_iter: request iterator, with length of 1
:return: the first result from the request iterator
"""
async for k in servicer.Call(request_iterator=req_iter):
async for k in servicer.send(request_iterator=req_iter):
return MessageToDict(
k, including_default_value_fields=True, use_integers_for_enums=True
) # DO NOT customize other serialization here. Scheme is handled by Pydantic in `models.py`
Expand Down
88 changes: 60 additions & 28 deletions jina/peapods/runtimes/gateway/prefetch.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,65 @@
import argparse
import asyncio
from abc import ABC
from typing import AsyncGenerator
from asyncio import Future
from typing import AsyncGenerator, Dict

from ....helper import typename
from ....helper import typename, get_or_reuse_loop
from ....logging.logger import JinaLogger
from ....types.message import Message

__all__ = ['PrefetchCaller', 'PrefetchMixin']
__all__ = ['PrefetchCaller']

if False:
from ...zmq import AsyncZmqlet


class PrefetchMixin(ABC):
"""JinaRPCServicer """
class PrefetchCaller:
"""An async zmq request sender to be used in the Gateway"""

async def Call(self, request_iterator, *args) -> AsyncGenerator[None, Message]:
def __init__(self, args: argparse.Namespace, zmqlet: 'AsyncZmqlet'):
"""
:param args: args from CLI
:param zmqlet: zeromq object
"""
self.args = args
self.zmqlet = zmqlet
self.name = args.name or self.__class__.__name__

self.logger = JinaLogger(self.name, **vars(args))
self._message_buffer: Dict[str, Future[Message]] = dict()
self._receive_task = get_or_reuse_loop().create_task(self._receive())

async def _receive(self):
try:
while True:
message = await self.zmqlet.recv_message(callback=lambda x: x.response)
# during shutdown the socket will return None
if message is None:
break

if message.request_id in self._message_buffer:
future = self._message_buffer.pop(message.request_id)
future.set_result(message)
else:
self.logger.warning(
f'Discarding unexpected message with request id {message.request_id}'
)
except asyncio.CancelledError:
raise
finally:
for future in self._message_buffer.values():
future.cancel(
'PrefetchCaller closed, all outstanding requests canceled'
)
self._message_buffer.clear()

async def close(self):
"""
Stop receiving messages
"""
self._receive_task.cancel()

async def send(self, request_iterator, *args) -> AsyncGenerator[None, Message]:
"""
Async call to receive Requests and build them into Messages.

Expand All @@ -28,6 +71,11 @@ async def Call(self, request_iterator, *args) -> AsyncGenerator[None, Message]:
self.zmqlet: 'AsyncZmqlet'
self.logger: JinaLogger

if self._receive_task.done():
raise RuntimeError(
'PrefetchCaller receive task not running, can not send messages'
)

async def prefetch_req(num_req, fetch_to):
"""
Fetch and send request.
Expand All @@ -46,16 +94,16 @@ async def prefetch_req(num_req, fetch_to):
raise TypeError(
f'{typename(request_iterator)} does not have `__anext__` or `__next__`'
)

future = get_or_reuse_loop().create_future()
self._message_buffer[next_request.request_id] = future
asyncio.create_task(
self.zmqlet.send_message(
Message(None, next_request, 'gateway', **vars(self.args))
)
)
fetch_to.append(
asyncio.create_task(
self.zmqlet.recv_message(callback=lambda x: x.response)
)
)

fetch_to.append(future)
except (StopIteration, StopAsyncIteration):
return True
return False
Expand Down Expand Up @@ -95,19 +143,3 @@ async def prefetch_req(num_req, fetch_to):
# this list dries, clear it and feed it with on_recv_task
prefetch_task.clear()
prefetch_task = [j for j in onrecv_task]


class PrefetchCaller(PrefetchMixin):
"""An async zmq request sender to be used in the Gateway"""

def __init__(self, args: argparse.Namespace, zmqlet: 'AsyncZmqlet'):
"""

:param args: args from CLI
:param zmqlet: zeromq object
"""
super().__init__()
self.args = args
self.zmqlet = zmqlet
self.name = args.name or self.__class__.__name__
self.logger = JinaLogger(self.name, **vars(args))
5 changes: 3 additions & 2 deletions jina/peapods/runtimes/gateway/websocket/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def disconnect(self, websocket: WebSocket):
servicer = PrefetchCaller(args, zmqlet)

@app.on_event('shutdown')
def _shutdown():
async def _shutdown():
await servicer.close()
zmqlet.close()

@app.websocket('/')
Expand All @@ -55,7 +56,7 @@ async def req_iter():
yield Request(data)

try:
async for msg in servicer.Call(request_iterator=req_iter()):
async for msg in servicer.send(request_iterator=req_iter()):
await websocket.send_bytes(msg.binary_str())
except WebSocketDisconnect:
manager.disconnect(websocket)
Expand Down
Empty file.
52 changes: 52 additions & 0 deletions tests/integration/concurrent_clients/test_concurrent_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

from jina import Flow, Executor, Client, requests, DocumentArray, Document
from jina.types.request import Response
import threading
import random
import time
from functools import partial


class MyExecutor(Executor):
    """Executor whose only endpoint stalls for a short, random amount of time."""

    @requests(on='/ping')
    def ping(self, docs: DocumentArray, **kwargs):
        """Handle `/ping` by sleeping for a random delay below 0.1 seconds.

        :param docs: incoming documents; this method neither reads nor changes them
        :param kwargs: remaining keyword arguments, ignored
        """
        delay = 0.1 * random.random()
        time.sleep(delay)


@pytest.mark.parametrize('protocal', ['http', 'grpc'])
@pytest.mark.parametrize('parallel', [10])
@pytest.mark.parametrize('polling', ['ANY', 'ALL'])
@pytest.mark.parametrize('prefetch', [1, 10])
@pytest.mark.parametrize('concurrent', [15])
def test_concurrent_clients(concurrent, protocal, parallel, polling, prefetch, reraise):
    """Check that clients sharing one Flow only ever get their own documents back.

    `concurrent` threads each create a client and post five documents whose
    text is that thread's id; every document of every response must carry the
    same id.

    :param concurrent: number of client threads to start
    :param protocal: gateway protocol used by both the Flow and the clients
    :param parallel: `parallel` setting of the executor
    :param polling: `polling` setting of the executor
    :param prefetch: `prefetch` setting of the Flow
    :param reraise: fixture used as a context manager around the assertions
        (presumably pytest-reraise, so failures raised in worker threads reach
        the test -- confirm against the project's test dependencies)
    """

    def pong(peer_hash, resp: Response):
        # Response callback: every returned document must belong to this peer.
        with reraise:
            assert all(d.text == peer_hash for d in resp.docs)

    def peer_client(port, protocal, peer_hash):
        # Body of one client thread: five sequential posts tagged with the peer id.
        client = Client(protocol=protocal, port_expose=port)
        for _ in range(5):
            client.post(
                '/ping',
                Document(text=peer_hash),
                on_done=lambda r: pong(peer_hash, r),
            )

    flow = Flow(protocol=protocal, prefetch=prefetch)
    flow = flow.add(uses=MyExecutor, parallel=parallel, polling=polling)

    with flow:
        port_expose = flow.port_expose

        workers = []
        for peer_id in range(concurrent):
            worker = threading.Thread(
                target=peer_client,
                args=(port_expose, protocal, str(peer_id)),
                daemon=True,
            )
            worker.start()
            workers.append(worker)

        # Keep the Flow open until every client thread has finished.
        for worker in workers:
            worker.join()
14 changes: 10 additions & 4 deletions tests/unit/peapods/runtimes/asyncio/rest/test_app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pytest
from fastapi.testclient import TestClient

from jina.logging.logger import JinaLogger
from jina.peapods.runtimes.gateway.http import get_fastapi_app
from jina.parsers import set_gateway_parser
import pytest
from jina.peapods.runtimes.gateway.http import get_fastapi_app


@pytest.mark.parametrize('p', [['--default-swagger-ui'], []])
def test_custom_swagger(p):
args = set_gateway_parser().parse_args(p)
logger = JinaLogger('')
app = get_fastapi_app(args, logger)
assert any('/docs' in r.path for r in app.routes)
assert any('/openapi.json' in r.path for r in app.routes)
# The TestClient is needed here as a context manager to generate the shutdown event correctly
# otherwise the app can hang as it is not cleaned up correctly
# see https://fastapi.tiangolo.com/advanced/testing-events/
with TestClient(app) as client:
assert any('/docs' in r.path for r in app.routes)
assert any('/openapi.json' in r.path for r in app.routes)
49 changes: 49 additions & 0 deletions tests/unit/peapods/runtimes/gateway/test_prefetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import asyncio
from asyncio import Future

import pytest

from jina.helper import ArgNamespace, random_identity
from jina.parsers import set_gateway_parser
from jina.peapods.runtimes.gateway.prefetch import PrefetchCaller
from jina.proto import jina_pb2


def _generate_request():
    """Build a request proto for the tests.

    :return: a `RequestProto` whose `request_id` comes from `random_identity()`
        and whose `data.docs` holds one freshly added, otherwise unset document
    """
    request = jina_pb2.RequestProto(request_id=random_identity())
    request.data.docs.add()
    return request


class ZmqletMock:
    """In-memory stand-in for the zmqlet used by the prefetch test.

    A single future hands one message at a time from `send_message` to
    `recv_message`; an event tells the sender that the receiver has consumed
    a message.
    """

    def __init__(self):
        self.received_event = asyncio.Event()
        self.sent_future = Future()

    async def recv_message(self, **kwargs):
        """Wait for the pending message, re-arm the hand-over future, return it.

        :param kwargs: accepted and ignored (callers pass e.g. `callback`)
        :return: the value the sender stored in the future
        """
        delivered = await self.sent_future
        # Install a fresh future before signalling, so the sender that wakes up
        # on the event finds an unresolved future to fill.
        self.sent_future = Future()
        self.received_event.set()
        return delivered

    async def send_message(self, message):
        """Deliver an unrelated request first, then the response of `message`.

        :param message: object whose `response` attribute is delivered second
        """
        unrelated = _generate_request()
        self.sent_future.set_result(unrelated)
        # Wait until the receiver has taken the first message.
        await self.received_event.wait()
        self.sent_future.set_result(message.response)


@pytest.mark.asyncio
async def test_concurrent_requests():
    """Send one request through `PrefetchCaller` and check what comes back.

    The mocked zmqlet first delivers a message with an unrelated request id
    and only then the real response, so each response yielded by `send` must
    equal the request that was sent.
    """
    args = ArgNamespace.kwargs2namespace({}, set_gateway_parser())
    mock_zmqlet = ZmqletMock()
    servicer = PrefetchCaller(args, mock_zmqlet)

    # PrefetchCaller starts a background receive task in its constructor;
    # close it even when an assertion fails so the task is not left running.
    try:
        request = _generate_request()

        response = servicer.send(iter([request]))

        async for r in response:
            assert r == request
    finally:
        await servicer.close()
0