fix: add async dry run method by deepankarm · Pull Request #4963 · jina-ai/serve · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

fix: add async dry run method #4963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions jina/clients/grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from jina.clients.base.grpc import GRPCBaseClient
from jina.clients.mixin import AsyncPostMixin, HealthCheckMixin, PostMixin
from jina.clients.mixin import (
AsyncHealthCheckMixin,
AsyncPostMixin,
HealthCheckMixin,
PostMixin,
)


class GRPCClient(GRPCBaseClient, PostMixin, HealthCheckMixin):
Expand All @@ -23,7 +28,7 @@ class GRPCClient(GRPCBaseClient, PostMixin, HealthCheckMixin):
"""


class AsyncGRPCClient(GRPCBaseClient, AsyncPostMixin, HealthCheckMixin):
class AsyncGRPCClient(GRPCBaseClient, AsyncPostMixin, AsyncHealthCheckMixin):
"""
Asynchronous client connecting to a Gateway using gRPC protocol.

Expand Down
3 changes: 2 additions & 1 deletion jina/clients/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jina.clients.base.http import HTTPBaseClient
from jina.clients.mixin import (
AsyncHealthCheckMixin,
AsyncMutateMixin,
AsyncPostMixin,
HealthCheckMixin,
Expand Down Expand Up @@ -30,7 +31,7 @@ class HTTPClient(HTTPBaseClient, PostMixin, MutateMixin, HealthCheckMixin):


class AsyncHTTPClient(
HTTPBaseClient, AsyncPostMixin, AsyncMutateMixin, HealthCheckMixin
HTTPBaseClient, AsyncPostMixin, AsyncMutateMixin, AsyncHealthCheckMixin
):
"""
Asynchronous client connecting to a Gateway using HTTP protocol.
Expand Down
13 changes: 13 additions & 0 deletions jina/clients/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if TYPE_CHECKING:
from docarray import DocumentArray

from jina.clients.base import CallbackFnType, InputType
from jina.types.request import Response

Expand Down Expand Up @@ -100,6 +101,18 @@ def dry_run(self, **kwargs) -> bool:
return run_async(self.client._dry_run, **kwargs)


class AsyncHealthCheckMixin:
    """Asynchronous health-check Mixin for Client and Flow.

    Exposes a `dry_run` coroutine that must be awaited. It forwards to the
    wrapped `client._dry_run` coroutine.
    """

    async def dry_run(self, **kwargs) -> bool:
        """Await a dry run against the Flow to check it is ready to receive requests.

        :param kwargs: potential kwargs received passed from the public interface;
            forwarded unchanged to `self.client._dry_run`
        :return: boolean indicating the health/readiness of the Flow
        """
        underlying = self.client
        readiness = await underlying._dry_run(**kwargs)
        return readiness


class PostMixin:
"""The Post Mixin class for Client and Flow"""

Expand Down
9 changes: 7 additions & 2 deletions jina/clients/websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from jina.clients.base.websocket import WebSocketBaseClient
from jina.clients.mixin import AsyncPostMixin, HealthCheckMixin, PostMixin
from jina.clients.mixin import (
AsyncHealthCheckMixin,
AsyncPostMixin,
HealthCheckMixin,
PostMixin,
)


class WebSocketClient(WebSocketBaseClient, PostMixin, HealthCheckMixin):
Expand All @@ -23,7 +28,7 @@ class WebSocketClient(WebSocketBaseClient, PostMixin, HealthCheckMixin):
"""


class AsyncWebSocketClient(WebSocketBaseClient, AsyncPostMixin, HealthCheckMixin):
class AsyncWebSocketClient(WebSocketBaseClient, AsyncPostMixin, AsyncHealthCheckMixin):
"""
Asynchronous client connecting to a Gateway using WebSocket protocol.

Expand Down
49 changes: 37 additions & 12 deletions tests/integration/runtimes/test_gateway_dry_run.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import asyncio
import json
import multiprocessing
import threading
import time
from collections import defaultdict

import pytest

from jina import Client, Document, Executor, requests
from jina.enums import PollingType
from jina import Client
from jina.parsers import set_gateway_parser, set_pod_parser
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime
from jina.serve.runtimes.gateway.grpc import GRPCGatewayRuntime
Expand Down Expand Up @@ -49,10 +43,7 @@ def _create_gateway_runtime(graph_description, pod_addresses, port, protocol='gr
runtime.run_forever()


@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
def test_dry_run_of_flow(port_generator, protocol):
worker_port = port_generator()
port = port_generator()
def _setup(worker_port, port, protocol):
graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'

Expand All @@ -79,11 +70,19 @@ def test_dry_run_of_flow(port_generator, protocol):
ctrl_address=f'0.0.0.0:{port}',
ready_or_shutdown_event=multiprocessing.Event(),
)
return worker_process, gateway_process


@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
def test_dry_run_of_flow(port_generator, protocol):
worker_port = port_generator()
port = port_generator()
worker_process, gateway_process = _setup(worker_port, port, protocol)
# send requests to the gateway
c = Client(host='localhost', port=port, asyncio=True, protocol=protocol)
c = Client(host='localhost', port=port, protocol=protocol)
dry_run_alive = c.dry_run()

# _teardown(worker_process, gateway_process, dry_run_alive)
worker_process.terminate()
worker_process.join()

Expand All @@ -97,3 +96,29 @@ def test_dry_run_of_flow(port_generator, protocol):

assert gateway_process.exitcode == 0
assert worker_process.exitcode == 0


@pytest.mark.asyncio
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
async def test_async_dry_run_of_flow(port_generator, protocol):
    """Check that an asyncio Client's awaitable `dry_run` tracks Flow health.

    A worker and a gateway are started via `_setup`. The first dry run is
    expected to report the Flow as healthy. After the worker is terminated,
    a second dry run is expected to report it as unhealthy. Both processes
    are expected to exit with code 0.
    """
    worker_port = port_generator()
    port = port_generator()
    worker_process, gateway_process = _setup(worker_port, port, protocol)
    try:
        # asyncio=True is expected to yield an async client whose dry_run must be awaited
        client = Client(host='localhost', asyncio=True, port=port, protocol=protocol)
        dry_run_alive = await client.dry_run()

        # Take down the only worker behind the gateway; the Flow should now be unhealthy.
        worker_process.terminate()
        worker_process.join()

        dry_run_worker_removed = await client.dry_run()
    finally:
        # Always reap both processes, even if a dry run raised, so a failing
        # parametrization does not leave processes (and their ports) behind.
        for process in (worker_process, gateway_process):
            if process.is_alive():
                process.terminate()
            process.join()

    assert dry_run_alive
    assert not dry_run_worker_removed

    assert gateway_process.exitcode == 0
    assert worker_process.exitcode == 0
0