etl: support direct put in websocket communicator · NVIDIA/aistore@5b46ac3 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Commit 5b46ac3

Browse files
committed
etl: support direct put in websocket communicator
* Sends two consecutive WebSocket messages:
  1. A `TextMessage` containing the direct PUT address.
  2. A `BinaryMessage` containing the object content.
* The ETL server is expected to consume them in the same order and treat them as logically linked.

Signed-off-by: Tony Chen <a122774007@gmail.com>
1 parent 7b46f0c commit 5b46ac3

File tree

6 files changed

+123
-29
lines changed

6 files changed

+123
-29
lines changed

ais/test/etl_test.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,19 @@ func TestETLBucketTransformParallel(t *testing.T) {
473473
parallel = 3
474474
)
475475

476-
for _, commType := range []string{etl.WebSocket, etl.Hpush, etl.Hpull} {
477-
t.Run("etl_bucket_transform_parallel__"+commType, func(t *testing.T) {
478-
_ = tetl.InitSpec(t, baseParams, transformer, commType)
476+
tests := []struct {
477+
commType string
478+
onlyLong bool
479+
}{
480+
{commType: etl.Hpush},
481+
{commType: etl.Hpull},
482+
{commType: etl.WebSocket, onlyLong: true},
483+
}
484+
485+
for _, test := range tests {
486+
t.Run("etl_bucket_transform_parallel__"+test.commType, func(t *testing.T) {
487+
tools.CheckSkip(t, &tools.SkipTestArgs{Long: test.onlyLong})
488+
_ = tetl.InitSpec(t, baseParams, transformer, test.commType)
479489
t.Cleanup(func() { tetl.StopAndDeleteETL(t, baseParams, transformer) })
480490

481491
wg := &sync.WaitGroup{}
@@ -528,7 +538,7 @@ func TestETLAnyToAnyBucket(t *testing.T) {
528538
bcktests = []testBucketConfig{{false, false, false}}
529539

530540
tests = []testObjConfig{
531-
{transformer: tetl.Echo, comm: etl.WebSocket},
541+
{transformer: tetl.Echo, comm: etl.WebSocket, onlyLong: true},
532542
{transformer: tetl.Echo, comm: etl.Hpull},
533543
{transformer: tetl.Echo, comm: etl.Hpush},
534544
}

ais/tgtobj.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1526,7 +1526,7 @@ func (coi *coi) do(t *target, dm *bundle.DM, lom *core.LOM) (res xs.CoiRes) {
15261526
daddr := cos.JoinPath(tsi.URL(cmn.NetIntraData), url.PathEscape(cos.UnsafeS(uname))) // use escaped URL to simplify parsing on the ETL side
15271527
resp := coi.GetROC(lom, coi.LatestVer, coi.Sync, daddr)
15281528
// skip the t2t send if GetROC fails, returns an empty reader, or reports http.StatusNoContent (the ETL delivers the object itself)
1529-
if resp.Err != nil || resp.R == nil {
1529+
if resp.Err != nil || resp.R == nil || resp.Ecode == http.StatusNoContent {
15301530
return xs.CoiRes{Err: resp.Err, Ecode: resp.Ecode}
15311531
}
15321532
coi.OAH = resp.OAH

ext/etl/websocket_comm.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ type (
6868
}
6969

7070
rwpair struct {
71-
r io.ReadCloser
72-
w *io.PipeWriter
71+
r io.ReadCloser
72+
w *io.PipeWriter
73+
daddr string
7374
}
7475
)
7576

@@ -154,14 +155,14 @@ func (ws *webSocketComm) Stop() {
154155
// wsConnCtx //
155156
///////////////
156157

157-
func (wctx *wsConnCtx) Transform(lom *core.LOM, latestVer, sync bool, _ string) core.ReadResp {
158+
func (wctx *wsConnCtx) Transform(lom *core.LOM, latestVer, sync bool, daddr string) core.ReadResp {
158159
srcResp := lom.GetROC(latestVer, sync)
159160
if srcResp.Err != nil {
160161
return srcResp
161162
}
162163
pr, pw := io.Pipe()
163164

164-
wctx.workCh <- rwpair{srcResp.R, pw}
165+
wctx.workCh <- rwpair{srcResp.R, pw, daddr}
165166
if l, c := len(wctx.workCh), cap(wctx.workCh); l > c/2 {
166167
runtime.Gosched() // poor man's throttle
167168
if l == c {
@@ -176,7 +177,7 @@ func (wctx *wsConnCtx) Transform(lom *core.LOM, latestVer, sync bool, _ string)
176177
R: cos.NopOpener(pr),
177178
OAH: srcResp.OAH, // TODO: estimate the post-transformed Lsize for stats
178179
Err: nil,
179-
Ecode: http.StatusOK,
180+
Ecode: http.StatusNoContent, // promise to deliver
180181
Remote: srcResp.Remote,
181182
}
182183
}
@@ -195,10 +196,6 @@ func (wctx *wsConnCtx) Finish(errCause error) error {
195196
}
196197

197198
func (wctx *wsConnCtx) readLoop() error {
198-
wctx.conn.SetPingHandler(func(msg string) error {
199-
return wctx.conn.WriteMessage(websocket.PongMessage, []byte(msg))
200-
})
201-
202199
buf, slab := core.T.PageMM().Alloc()
203200
defer slab.Free(buf)
204201

@@ -211,20 +208,23 @@ func (wctx *wsConnCtx) readLoop() error {
211208
case <-wctx.etlxctn.ChanAbort():
212209
return nil
213210
default:
214-
_, r, err := wctx.conn.NextReader()
211+
ty, r, err := wctx.conn.NextReader()
215212
if err != nil {
216213
err = fmt.Errorf("error on reading message from %s: %w", wctx.name, err)
217214
wctx.txctn.AddErr(err)
218-
return err
219215
}
220216
writer := <-wctx.writerCh
221217

222-
// TODO: do not abort on soft error
218+
// direct put success (TextMessage ack from ETL server)
219+
if ty == websocket.TextMessage {
220+
cos.Close(writer)
221+
continue
222+
}
223+
223224
if _, err = cos.CopyBuffer(writer, r, buf); err != nil {
224225
cos.Close(writer)
225226
err = fmt.Errorf("error on copying message from %s: %w", wctx.name, err)
226227
wctx.txctn.AddErr(err)
227-
return err
228228
}
229229
cos.Close(writer)
230230
}
@@ -244,26 +244,36 @@ func (wctx *wsConnCtx) writeLoop() error {
244244
case <-wctx.etlxctn.ChanAbort():
245245
return nil
246246
case rw := <-wctx.workCh:
247+
// Leverages the fact that WebSocket preserves message order and boundaries.
248+
// Sends two consecutive WebSocket messages:
249+
// 1. A TextMessage containing the direct PUT address
250+
// 2. A BinaryMessage containing the object content
251+
//
252+
// The ETL server is expected to consume them in the same order and treat them as logically linked.
253+
254+
// 1. send direct put address
255+
err := wctx.conn.WriteMessage(websocket.TextMessage, []byte(rw.daddr))
256+
if err != nil {
257+
err = fmt.Errorf("error on writing direct put address %s: %w", wctx.name, err)
258+
wctx.txctn.AddErr(err)
259+
}
260+
261+
// 2. send object content
247262
writer, err := wctx.conn.NextWriter(websocket.BinaryMessage)
248263
if err != nil {
249264
err = fmt.Errorf("error on getting next writter from %s: %w", wctx.name, err)
250265
wctx.txctn.AddErr(err)
251-
return err
252266
}
253-
254-
// TODO: do not abort on soft error
255267
if _, err := cos.CopyBuffer(writer, rw.r, buf); err != nil {
256268
cos.Close(rw.r)
257269
cos.Close(writer)
258270
err = fmt.Errorf("error on getting next writter from %s: %w", wctx.name, err)
259271
wctx.txctn.AddErr(err)
260-
return err
261272
}
262273
cos.Close(rw.r)
263-
// Leverages the fact that WebSocket preserves message order and boundaries.
274+
264275
// For each object sent to the ETL via `wctx.workCh`, we reserve a corresponding slot in `wctx.writerCh`
265276
// for its writer. This guarantees that the correct writer is matched with the response when it arrives.
266-
267277
wctx.writerCh <- rw.w
268278
if l, c := len(wctx.writerCh), cap(wctx.writerCh); l > c/2 {
269279
runtime.Gosched() // poor man's throttle

python/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ We structure this changelog in accordance with [Keep a Changelog](https://keepac
1313
### Changed
1414

1515
- Adjust error handling logic to accommodate updated object naming (see [here](https://github.com/NVIDIA/aistore/blob/main/docs/unicode.md)).
16+
- Add `direct put` support to the WebSocket endpoint of the ETL FastAPI web server.
1617

1718
### Removed
1819

python/aistore/sdk/etl/webserver/fastapi_server.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,29 @@ async def websocket_endpoint(websocket: WebSocket):
6666
self.active_connections.append(websocket)
6767

6868
while True:
69-
data = await websocket.receive_bytes()
70-
self.logger.info("Received message of length: %d", len(data))
69+
delivery_target_url = None
70+
msg = await websocket.receive()
71+
if "text" in msg:
72+
delivery_target_url = msg["text"]
73+
data = await websocket.receive_bytes()
74+
elif "bytes" in msg:
75+
data = msg["bytes"]
76+
else:
77+
self.logger.warning("Received unknown message format: %s", msg)
78+
continue
79+
80+
self.logger.debug("Received message of length: %d", len(data))
81+
7182
transformed = await asyncio.to_thread(self.transform, data, "")
83+
84+
if delivery_target_url:
85+
response = await self._direct_put(
86+
delivery_target_url, transformed
87+
)
88+
if response:
89+
await websocket.send_text("direct put success")
90+
continue
91+
7292
await websocket.send_bytes(transformed)
7393

7494
except WebSocketDisconnect:

python/tests/unit/sdk/test_etl_webserver.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,10 @@ async def test_direct_put_delivery(self):
152152
input_content = b"input data"
153153

154154
# Mock the direct delivery response (simulate 200 OK)
155-
mock_response = AsyncMock()
156-
mock_response.status_code = 200
155+
mock_response_success = AsyncMock()
156+
mock_response_success.status_code = 200
157157
self.etl_server.client = AsyncMock()
158-
self.etl_server.client.put.return_value = mock_response
158+
self.etl_server.client.put.return_value = mock_response_success
159159

160160
headers = {HEADER_NODE_URL: "http://localhost:8080/ais/@/etl_dst/test/object"}
161161
response = self.client.put(f"/{path}", content=input_content, headers=headers)
@@ -164,6 +164,59 @@ async def test_direct_put_delivery(self):
164164
self.assertEqual(response.content, b"") # No content returned
165165
self.etl_server.client.put.assert_awaited_once()
166166

167+
# Mock the direct delivery response (simulate 500 FAIL)
168+
mock_response_fail = AsyncMock()
169+
mock_response_fail.status_code = 500
170+
self.etl_server.client = AsyncMock()
171+
self.etl_server.client.put.return_value = mock_response_fail
172+
173+
headers = {HEADER_NODE_URL: "http://localhost:8080/ais/@/etl_dst/test/object"}
174+
response = self.client.put(f"/{path}", content=input_content, headers=headers)
175+
176+
self.assertEqual(response.status_code, 200)
177+
self.assertEqual(
178+
response.content, input_content[::-1]
179+
) # Direct put failed, so the transformed content is returned in the response
180+
self.etl_server.client.put.assert_awaited_once()
181+
182+
@unittest.skipIf(sys.version_info < (3, 9), "requires Python 3.9 or higher")
183+
async def test_websocket(self):
184+
with self.client.websocket_connect("/ws") as websocket:
185+
original_data = b"abcdef"
186+
websocket.send_text("") # No delivery target URL
187+
websocket.send_bytes(original_data)
188+
result = websocket.receive_bytes()
189+
self.assertEqual(result, original_data[::-1])
190+
191+
@unittest.skipIf(sys.version_info < (3, 9), "requires Python 3.9 or higher")
192+
async def test_websocket_with_direct_put(self):
193+
input_data = b"testdata"
194+
# Mock the direct delivery response (simulate 200 OK)
195+
with patch.object(self.etl_server, "client", new=AsyncMock()) as mock_client:
196+
mock_resp = AsyncMock()
197+
mock_resp.status_code = 200
198+
mock_client.put.return_value = mock_resp
199+
200+
with self.client.websocket_connect("/ws") as websocket:
201+
websocket.send_text("http://localhost:8080/ais/@/etl_dst/final")
202+
websocket.send_bytes(input_data)
203+
result = websocket.receive_text()
204+
self.assertEqual(result, "direct put success")
205+
mock_client.put.assert_awaited_once()
206+
207+
# Mock the direct delivery response (simulate 500 FAIL)
208+
with patch.object(self.etl_server, "client", new=AsyncMock()) as mock_client:
209+
mock_resp = AsyncMock()
210+
mock_resp.status_code = 500
211+
mock_client.put.return_value = mock_resp
212+
213+
with self.client.websocket_connect("/ws") as websocket:
214+
websocket.send_text("http://localhost:8080/ais/@/etl_dst/final")
215+
websocket.send_bytes(input_data)
216+
result = websocket.receive_bytes()
217+
self.assertEqual(result, input_data[::-1])
218+
mock_client.put.assert_awaited_once()
219+
167220

168221
class TestFlaskServer(unittest.TestCase):
169222
def setUp(self):

0 commit comments

Comments
 (0)
0