8000 python/etl: Implement `FastAPIServer` base class for async ETL proces… · NVIDIA/aistore@e1e4935 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Commit e1e4935

Browse files
python/etl: Implement FastAPIServer base class for async ETL processing
- prev: 7625a96 Signed-off-by: Abhishek Gaikwad <gaikwadabhishek1997@gmail.com>
1 parent 504f218 commit e1e4935

File tree

6 files changed

+310
-13
lines changed

6 files changed

+310
-13
lines changed

python/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ We structure this changelog in accordance with [Keep a Changelog](https://keepac
1515
- Add abstract base class `ETLServer` defining the common interface for ETL servers.
1616
- Includes abstract methods: `transform` and `start`.
1717
- Implement `HTTPMultiThreadedServer`, a multi-threaded server based on `BaseHTTPRequestHandler`.
18+
- Implement `FastAPIServer` base class for async ETL processing.
1819

1920
### Changed
2021
- `props` accessor ensures object properties are refreshed via HEAD request on every access.

python/aistore/common_requirements

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ pylint>=3.2.7
1414
pytest>=8.3.5
1515
typing-extensions>=4.12.2
1616
webdataset>=0.2.100
17-
tenacity>=9.0.0
17+
tenacity>=9.0.0
18+
fastapi>=0.109.1
19+
httpx>=0.28.0
20+
aiofiles>=23.2.1
21+
uvicorn>=0.32.0

python/aistore/sdk/etl/webserver/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@
33
#
44

55
from aistore.sdk.etl.webserver.http_multi_threaded_server import HTTPMultiThreadedServer
6+
from aistore.sdk.etl.webserver.fastapi_server import FastAPIServer
67

7-
__all__ = ["HTTPMultiThreadedServer"]
8+
9+
__all__ = ["HTTPMultiThreadedServer", "FastAPIServer"]
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import os
2+
import asyncio
3+
from urllib.parse import unquote, quote
4+
from typing import Optional, List
5+
6+
from fastapi import (
7+
FastAPI,
8+
Request,
9+
HTTPException,
10+
Response,
11+
WebSocket,
12+
WebSocketDisconnect,
13+
)
14+
import httpx
15+
import aiofiles
16+
import uvicorn
17+
18+
from aistore.sdk.etl.webserver.base_etl_server import ETLServer
19+
20+
21+
class FastAPIServer(ETLServer):
    """
    FastAPI server implementation for ETL transformations.

    Request handling is async; the synchronous ``transform`` supplied by a
    subclass is run in a worker thread via ``asyncio.to_thread`` so that a slow
    transform does not block the event loop.

    Routes registered by ``_setup_app``:
        GET  /health   liveness probe, responds with ``b"Running"``.
        GET  /{path}   load the object (local FQN or from the AIS target) and transform it.
        PUT  /{path}   transform the request body (or the local FQN content).
        WS   /ws       transform every binary message and send the result back.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 80):
        """
        Args:
            host: Interface uvicorn binds to when ``start`` is called.
            port: Port uvicorn listens on when ``start`` is called.
        """
        super().__init__()
        self.host = host
        self.port = port
        self.app = FastAPI()
        # Shared HTTP client used to fetch objects from the AIS target. Created in
        # `startup_event`, or on demand by `_get_network_content` if startup has not run.
        self.client: Optional[httpx.AsyncClient] = None
        # WebSocket connections currently being served by the /ws endpoint.
        self.active_connections: List[WebSocket] = []
        self._setup_app()

    def _setup_app(self):
        """Configure FastAPI routes and event handlers."""
        self.app.state.etl_server = self
        self.app.add_event_handler("startup", self.startup_event)
        self.app.add_event_handler("shutdown", self.shutdown_event)

        # Registered before the catch-all GET route so that it takes precedence.
        @self.app.get("/health")
        async def health_check():
            return Response(content=b"Running")

        @self.app.get("/{path:path}")
        async def handle_get(path: str, request: Request):
            return await self._handle_request(path, request, is_get=True)

        @self.app.put("/{path:path}")
        async def handle_put(path: str, request: Request):
            return await self._handle_request(path, request, is_get=False)

        @self.app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            self.logger.debug(
                "New WebSocket connection attempt from: %s", websocket.client
            )

            try:
                await websocket.accept()
                self.logger.debug(
                    "WebSocket connection established: %s", websocket.client
                )
                self.active_connections.append(websocket)

                while True:
                    data = await websocket.receive_bytes()
                    self.logger.info("Received message of length: %d", len(data))
                    transformed = await asyncio.to_thread(self.transform, data, "")
                    await websocket.send_bytes(transformed)

            except WebSocketDisconnect:
                self.logger.warning("WebSocket disconnected: %s", websocket.client)

            except Exception as e:
                self.logger.error(
                    "Unexpected WebSocket error from %s: %s", websocket.client, e
                )
                try:
                    await websocket.close()
                except RuntimeError:
                    # The socket may already be closed (or was never accepted);
                    # there is nothing left to clean up on the wire.
                    pass

            finally:
                # The connection is only registered after a successful accept(), so
                # an unconditional remove() would raise ValueError on early failures.
                if websocket in self.active_connections:
                    self.active_connections.remove(websocket)

    async def startup_event(self):
        """Initialize resources on server startup."""
        self.client = httpx.AsyncClient(timeout=None)
        self.logger.info("Server starting up")

    async def shutdown_event(self):
        """Cleanup resources on server shutdown (safe even if startup never ran)."""
        if self.client is not None:
            await self.client.aclose()
            self.client = None
        self.logger.info("Server shutting down")

    async def _handle_request(self, path: str, request: Request, is_get: bool):
        """
        Unified request handler for GET/PUT operations.

        Args:
            path: Request path captured by the ``{path:path}`` route.
            request: Incoming request; its body is the transform input for PUT
                when ``arg_type`` is not ``"fqn"``.
            is_get: True for GET, False for PUT.

        Returns:
            Response carrying the transformed bytes with ``get_mime_type()`` as
            its media type.

        Raises:
            HTTPException: 404 if the local FQN file is missing; the target's status
                code if the target returned an error; 502 on network failures;
                500 for anything else (including errors raised by ``transform``).
        """
        self.logger.info(
            "Processing %s request for path: %s", "GET" if is_get else "PUT", path
        )

        try:
            # `arg_type` and `transform` come from the ETLServer base / subclass.
            if self.arg_type == "fqn":
                content = await self._get_fqn_content(path)
            else:
                content = (
                    await self._get_network_content(path)
                    if is_get
                    else await request.body()
                )

            # Run the user's (synchronous) transform off the event loop.
            transformed = await asyncio.to_thread(self.transform, content, path)
            return self._build_response(transformed, self.get_mime_type())

        except FileNotFoundError as exc:
            self.logger.error("File not found: %s", path)
            raise HTTPException(
                404,
                detail=(
                    f"Local file not found: {path}. "
                    "This typically indicates the ETL container was not started with the correct volume mounts. "
                    "Please verify your ETL specification includes the necessary mount paths."
                ),
            ) from exc
        except httpx.HTTPStatusError as e:
            self.logger.warning(
                "Target responded with error: %s", e.response.status_code
            )
            raise HTTPException(
                e.response.status_code, detail="Target request failed"
            ) from e
        except httpx.RequestError as e:
            self.logger.error("Network error: %s", str(e))
            raise HTTPException(502, detail=f"Network error: {str(e)}") from e
        except Exception as e:
            self.logger.exception("Critical error during processing")
            raise HTTPException(500, detail=f"Processing error: {str(e)}") from e

    async def _get_fqn_content(self, path: str) -> bytes:
        """
        Read a local file named by ``path``.

        ``path`` is URL-decoded and lexically normalized to an absolute path
        (``..`` cannot climb above ``/``). Note this does not restrict which
        directory the file may live in.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        decoded_path = unquote(path)
        safe_path = os.path.normpath(os.path.join("/", decoded_path.lstrip("/")))
        self.logger.info("Reading local file: %s", safe_path)

        async with aiofiles.open(safe_path, "rb") as f:
            return await f.read()

    async def _get_network_content(self, path: str) -> bytes:
        """
        Retrieve content from the AIS target with the async HTTP client.

        ``path`` is re-quoted (keeping only ``@`` unescaped) and appended to
        ``host_target``.

        Raises:
            httpx.HTTPStatusError: If the target responds with an error status.
            httpx.RequestError: On network-level failures.
        """
        if self.client is None:
            # `startup_event` may not have run, so create the client on demand.
            self.client = httpx.AsyncClient(timeout=None)

        obj_path = quote(path, safe="@")
        target_url = f"{self.host_target}/{obj_path}"
        self.logger.info("Forwarding to target: %s", target_url)

        response = await self.client.get(target_url)
        response.raise_for_status()
        return response.content

    def _build_response(self, content: bytes, mime_type: str) -> Response:
        """Construct a response with the given body, media type and explicit Content-Length."""
        return Response(
            content=content,
            media_type=mime_type,
            headers={"Content-Length": str(len(content))},
        )

    def start(self):
        """Serve the app with uvicorn on ``self.host``:``self.port``; blocks until the server exits."""
        uvicorn.run(self.app, host=self.host, port=self.port)

python/pyproject.toml

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,24 @@ keywords = [
6363
"ETL",
6464
"Petascale",
6565
"High Performance",
66-
"Lightweight Object Storage"
66+
"Lightweight Object Storage",
6767
]
6868

6969
[project.optional-dependencies]
70-
pytorch = ["torch", "torchdata", "webdataset>=0.2.100"]
71-
botocore = ["wrapt"]
70+
pytorch = [
71+
"torch",
72+
"torchdata",
73+
"webdataset>=0.2.100",
74+
]
75+
botocore = [
76+
"wrapt",
77+
]
78+
etl = [
79+
"fastapi>=0.109.1",
80+
"httpx>=0.28.0",
81+
"aiofiles>=23.2.1",
82+
"uvicorn>=0.32.0",
83+
]
7284

7385
[project.urls]
7486
"Homepage" = "https://aistore.nvidia.com"

python/tests/unit/sdk/test_etl_webserver.py

Lines changed: 120 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import os
2+
import sys
23
import unittest
34
from http.server import BaseHTTPRequestHandler
4-
from unittest.mock import MagicMock
5+
from unittest.mock import MagicMock, patch, AsyncMock
6+
from fastapi.testclient import TestClient
57
from aistore.sdk.etl.webserver import HTTPMultiThreadedServer
8+
from aistore.sdk.etl.webserver.fastapi_server import FastAPIServer
69

710

811
class DummyETLServer(HTTPMultiThreadedServer):
9-
"""
10-
Dummy ETL server for testing transform and MIME type override.
11-
"""
12+
"""Dummy ETL server for testing transform and MIME type override."""
1213

1314
def transform(self, data: bytes, path: str) -> bytes:
1415
return data.upper()
@@ -17,12 +18,10 @@ def get_mime_type(self) -> str:
1718
return "text/caps"
1819

1920

20-
# pylint: disable=super-init-not-called
2121
class DummyRequestHandler(BaseHTTPRequestHandler):
22-
"""
23-
Fake handler that mocks HTTP methods for isolated testing of _set_headers().
24-
"""
22+
"""Fake handler that mocks HTTP methods for isolated testing of `set_headers()`."""
2523

24+
# pylint: disable=super-init-not-called
2625
def __init__(self):
2726
# Don't call super().__init__(), just mock the necessary parts
2827
self.send_response = MagicMock()
@@ -38,6 +37,14 @@ def set_headers(self, status_code: int = 200):
3837
self.end_headers()
3938

4039

40+
class DummyFastAPIServer(FastAPIServer):
    """
    Minimal concrete `FastAPIServer` for tests: reverses every payload and
    reports a fixed MIME type.
    """

    def transform(self, data: bytes, path: str) -> bytes:
        # A byte reversal makes it obvious in assertions that the transform ran.
        reversed_data = data[::-1]
        return reversed_data

    def get_mime_type(self) -> str:
        mime_type = "application/test"
        return mime_type
46+
47+
4148
class TestETLServerLogic(unittest.TestCase):
4249
def setUp(self):
4350
os.environ["AIS_TARGET_URL"] = "http://localhost:8080"
@@ -64,3 +71,108 @@ def test_set_headers_calls_expected_methods(self):
6471
handler.send_response.assert_called_once_with(202)
6572
handler.send_header.assert_called_once_with("Content-Type", "application/test")
6673
handler.end_headers.assert_called_once()
74+
75+
76+
class TestFastAPIServer(unittest.IsolatedAsyncioTestCase):
    """Unit tests for the HTTP handlers exposed by `FastAPIServer`."""

    def setUp(self):
        # Scope the environment change to this test so it cannot leak into others.
        env_patcher = patch.dict(
            os.environ, {"AIS_TARGET_URL": "http://localhost:8080"}
        )
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        self.etl_server = DummyFastAPIServer()
        self.client = TestClient(self.etl_server.app)
        self.addCleanup(self.client.close)

    def test_health_check(self):
        response = self.client.get("/health")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b"Running")

    # pylint: disable=protected-access
    async def test_get_network_content(self):
        path = "test/path"
        fake_content = b"fake data"

        # An httpx response object is not awaitable, so a plain MagicMock models it.
        mock_response = MagicMock()
        mock_response.content = fake_content

        with patch.object(self.etl_server, "client", AsyncMock()) as mock_client:
            mock_client.get.return_value = mock_response
            result = await self.etl_server._get_network_content(path)

        self.assertEqual(result, fake_content)
        mock_client.get.assert_awaited_once()
        mock_response.raise_for_status.assert_called_once()

    # The synchronous TestClient drives the app itself, so these tests need not be
    # coroutines (an `async def` here would only block the test's event loop).
    @unittest.skipIf(sys.version_info < (3, 9), "requires Python 3.9 or higher")
    def test_handle_get_request(self):
        self.etl_server.arg_type = ""
        path = "test/object"
        original_content = b"original data"
        transformed_content = original_content[::-1]

        with patch.object(
            self.etl_server,
            "_get_network_content",
            AsyncMock(return_value=original_content),
        ) as mock_get:
            response = self.client.get(f"/{path}")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, transformed_content)
        self.assertEqual(response.headers["content-type"], "application/test")
        mock_get.assert_awaited_once_with(path)

    @unittest.skipIf(sys.version_info < (3, 9), "requires Python 3.9 or higher")
    def test_handle_put_request(self):
        self.etl_server.arg_type = ""
        path = "test/object"
        input_content = b"input data"
        transformed_content = input_content[::-1]

        response = self.client.put(f"/{path}", content=input_content)

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, transformed_content)
132+
133+
134+
class TestBaseEnforcement(unittest.TestCase):
    """Checks construction-time contracts enforced by the ETL server base classes."""

    def setUp(self):
        # Snapshot os.environ; every change a test makes is rolled back afterwards,
        # so removing/setting AIS_TARGET_URL here cannot affect other test cases.
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    def test_fastapi_server_without_target_url(self):
        os.environ.pop("AIS_TARGET_URL", None)

        class MinimalFastAPIServer(FastAPIServer):
            def transform(self, data: bytes, path: str) -> bytes:
                return data

        with self.assertRaises(EnvironmentError) as context:
            MinimalFastAPIServer()
        self.assertIn("AIS_TARGET_URL", str(context.exception))

    def test_http_server_server_without_target_url(self):
        os.environ.pop("AIS_TARGET_URL", None)

        class MinimalHTTPServer(HTTPMultiThreadedServer):
            def transform(self, data: bytes, path: str) -> bytes:
                return data

        with self.assertRaises(EnvironmentError) as context:
            MinimalHTTPServer()
        self.assertIn("AIS_TARGET_URL", str(context.exception))

    def test_http_multithreaded_server_without_transform(self):
        os.environ.pop("AIS_TARGET_URL", None)

        class IncompleteETLServer(HTTPMultiThreadedServer):
            pass

        with self.assertRaises(TypeError) as context:
            IncompleteETLServer()  # pylint: disable=abstract-class-instantiated
        self.assertIn("Can't instantiate abstract class", str(context.exception))

    def test_fastapi_server_without_transform(self):
        os.environ["AIS_TARGET_URL"] = "http://localhost"

        class IncompleteFastAPIServer(FastAPIServer):
            pass

        with self.assertRaises(TypeError) as context:
            IncompleteFastAPIServer()  # pylint: disable=abstract-class-instantiated
        self.assertIn("Can't instantiate abstract class", str(context.exception))

0 commit comments

Comments
 (0)
0