8000 python/sdk: support direct put in etl flask webserver · NVIDIA/aistore@aaa45f1 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Commit aaa45f1

Browse files
committed
python/sdk: support direct put in etl flask webserver
Signed-off-by: Tony Chen <a122774007@gmail.com>
1 parent f5a3fc6 commit aaa45f1

File tree

3 files changed

+78
-7
lines changed

3 files changed

+78
-7
lines changed

python/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ We structure this changelog in accordance with [Keep a Changelog](https://keepac
1111
### Added
1212

1313
- Add `direct put` support in the websocket endpoint of ETL fastapi web server.
14+
- Add `direct put` support in PUT and GET endpoints of ETL flask web server.
1415
- Introduce `glob_id` field on `JobSnapshot` and implement `unpack()` to decode (njoggers, nworkers, chan_full) from the packed integer.
1516

1617
### Changed

python/aistore/sdk/etl/webserver/flask_server.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
2-
from urllib.parse import unquote, quote
2+
from urllib.parse import unquote, quote, urlparse, urlunparse
33

44
import requests
55
from flask import Flask, request, Response, jsonify
66

77
from aistore.sdk.etl.webserver.base_etl_server import ETLServer
8+
from aistore.sdk.const import HEADER_NODE_URL
89

910

1011
class FlaskServer(ETLServer):
@@ -31,9 +32,19 @@ def _health(self):
3132

3233
def _handle_request(self, path):
3334
try:
35+
transformed = None
3436
if request.method == "GET":
35-
return self._handle_get(path)
36-
return self._handle_put(path)
37+
transformed = self._handle_get(path)
38+
else:
39+
transformed = self._handle_put(path)
40+
41+
direct_put_url = request.headers.get(HEADER_NODE_URL)
42+
if direct_put_url:
43+
resp = self._direct_put(direct_put_url, transformed)
44+
if resp is not None:
45+
return resp
46+
47+
return Response(response=transformed, content_type=self.get_mime_type())
3748
except FileNotFoundError:
3849
self.logger.error("File not found: %s", path)
3950
return (
@@ -64,17 +75,15 @@ def _handle_get(self, path):
6475
resp.raise_for_status()
6576
content = resp.content
6677

67-
transformed = self.transform(content, path)
68-
return Response(response=transformed, content_type=self.get_mime_type())
78+
return self.transform(content, path)
6979

7080
def _handle_put(self, path):
7181
if self.arg_type == "fqn":
7282
content = self._get_fqn_content(path)
7383
else:
7484
content = request.get_data()
7585

76-
transformed = self.transform(content, path)
77-
return Response(response=transformed, content_type=self.get_mime_type())
86+
return self.transform(content, path)
7887

7988
def _get_fqn_content(self, path: str) -> bytes:
8089
decoded_path = unquote(path)
@@ -83,6 +92,40 @@ def _get_fqn_content(self, path: str) -> bytes:
8392
with open(safe_path, "rb") as f:
8493
return f.read()
8594

95+
def _direct_put(self, direct_put_url: str, data: bytes):
    """
    Sends the transformed object directly to the specified AIS node (`direct_put_url`),
    eliminating the additional network hop through the original target.
    Used only in bucket-to-bucket offline transforms.

    The destination URL is derived from `self.host_target`: its network location
    is replaced by the one in `direct_put_url`, and the path of `direct_put_url`
    is appended to the path of `self.host_target`.

    Delivery is best-effort. Failures are logged as warnings and never raised,
    so the caller can fall back to returning the transformed bytes itself.

    Args:
        direct_put_url: URL of the destination node, as supplied by the caller.
        data: Transformed object content to upload.

    Returns:
        An empty Flask `Response` with status 204 when the node replies 200;
        `None` on any other status code or on any exception.
    """
    try:
        parsed_target = urlparse(direct_put_url)
        parsed_host = urlparse(self.host_target)
        url = urlunparse(
            parsed_host._replace(
                netloc=parsed_target.netloc,
                path=parsed_host.path + parsed_target.path,
            )
        )

        # No client-side timeout is applied to the upload.
        resp = requests.put(url, data, timeout=None)
        if resp.status_code == 200:
            return Response(status=204)

        # `requests.Response.text` is a property. Calling it as `resp.text()`
        # raises TypeError, which the handler below would swallow, hiding the
        # real HTTP status and body of the failed delivery.
        self.logger.warning(
            "Failed to deliver object to %s: HTTP %s, %s",
            direct_put_url,
            resp.status_code,
            resp.text,
        )
    except Exception as e:
        # Deliberately broad: direct put is an optimization, so any failure
        # here must degrade to the regular response path instead of erroring.
        self.logger.warning(
            "Exception during delivery to %s: %s", direct_put_url, e
        )

    return None
128+
86129
# Example Gunicorn command to run this server:
87130
# command: ["gunicorn", "your_module:flask_app", "--bind", "0.0.0.0:8000", "--workers", "4"]
88131
def start(self):

python/tests/unit/sdk/test_etl_webserver.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,33 @@ def test_transform_get(self):
252252
assert "Content-Length" in response.headers
253253
assert int(response.headers["Content-Length"]) == len(response.data)
254254

255+
@unittest.skipIf(sys.version_info < (3, 9), "requires Python 3.9 or higher")
def test_direct_put_delivery(self):
    """
    Exercise the direct-put path of the Flask ETL server's PUT endpoint.

    With the node-URL header present, a 200 from the destination node must
    yield an empty 204 response; any other status must fall back to returning
    the transformed content with a 200.
    """
    self.etl_server.arg_type = ""
    path = "test/object"
    input_content = b"input data"
    headers = {HEADER_NODE_URL: "http://localhost:8080/ais/@/etl_dst/test/object"}

    with patch("aistore.sdk.etl.webserver.flask_server.requests.put") as mock_put:
        # Mock the direct delivery response (simulate 200 OK)
        mock_put.return_value = MagicMock(status_code=200)
        response = self.client.put(f"/{path}", data=input_content, headers=headers)

    # The server must actually have attempted the direct delivery.
    mock_put.assert_called_once()
    self.assertEqual(response.status_code, 204)
    self.assertEqual(response.data, b"")  # No content returned

    with patch("aistore.sdk.etl.webserver.flask_server.requests.put") as mock_put:
        # Mock the direct delivery response (simulate 500 FAIL).
        # `text` is a plain string, as on a real `requests.Response`, so the
        # mock does not accept being called like a method.
        mock_put.return_value = MagicMock(
            status_code=500, text="Internal Server Error"
        )
        response = self.client.put(f"/{path}", data=input_content, headers=headers)

    mock_put.assert_called_once()
    self.assertEqual(response.status_code, 200)
    self.assertEqual(
        response.data, b"flask: " + input_content
    )  # Original content returned
    self.assertIn("Content-Length", response.headers)
    self.assertEqual(int(response.headers["Content-Length"]), len(response.data))
281+
255282

256283
class TestBaseEnforcement(unittest.TestCase):
257284
def test_fastapi_server_without_target_url(self):

0 commit comments

Comments
 (0)
0