feat: Migrate NATS Queue to Rust (#669) by jthomson04 · Pull Request #961 · ai-dynamo/dynamo · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat: Migrate NATS Queue to Rust (#669) #961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 8 additions & 78 deletions examples/llm/utils/nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
from contextlib import asynccontextmanager
from typing import ClassVar, Optional

from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
from dynamo._core import NatsQueue


class NATSQueue:
Expand All @@ -34,15 +31,7 @@ def __init__(
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
self.nats_q = NatsQueue(stream_name, nats_server, dequeue_timeout)

@classmethod
@asynccontextmanager
Expand Down Expand Up @@ -81,79 +70,20 @@ async def shutdown(cls):
cls._instance = None

async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name,
subjects=[self._subject],
# TODO: make these configurable and add guide to set these values
max_bytes=1073741824, # 1GB total storage limit
max_msgs=1000000, # 1 million messages limit
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
await self.nats_q.connect()

async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
await self.nats_q.ensure_connection()

async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
await self.nats_q.close()

# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
await self.nats_q.enqueue_task(task_data)

async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
return await self.nats_q.dequeue_task()

async def get_queue_size(self) -> int:
"""Get the number of messages currently in the queue"""
await self.ensure_connection()
try:
# Get consumer info to get pending messages count
consumer_info = await self._js.consumer_info( # type: ignore
self._stream_name, "worker-group"
)
# Return number of pending messages (real-time queue size)
return consumer_info.num_pending
except NatsError as e:
raise RuntimeError(f"Failed to get queue size: {e}")
return await self.nats_q.get_queue_size()
1 change: 1 addition & 0 deletions lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::KvMetricsAggregator>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::KvRecorder>()?;
m.add_class::<llm::nats::NatsQueue>()?;
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
Expand Down
1 change: 1 addition & 0 deletions lib/bindings/python/rust/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ pub mod backend;
pub mod disagg_router;
pub mod kv;
pub mod model_card;
pub mod nats;
pub mod preprocessor;
103 changes: 103 additions & 0 deletions lib/bindings/python/rust/llm/nats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;

#[pyclass(subclass)]
pub(crate) struct NatsQueue {
inner: Arc<Mutex<crate::rs::transports::nats::NatsQueue>>,
}

#[pymethods]
impl NatsQueue {
#[new]
#[pyo3(signature = (stream_name, nats_server, dequeue_timeout))]
fn new(stream_name: String, nats_server: String, dequeue_timeout: f64) -> PyResult<Self> {
let inner = Arc::new(Mutex::new(crate::rs::transports::nats::NatsQueue::new(
stream_name,
nats_server,
std::time::Duration::from_secs(dequeue_timeout as u64),
)));
Ok(Self { inner })
}

fn connect<'p>(&mut self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
queue.lock().await.connect().await.map_err(to_pyerr)?;
Ok(())
})
}

fn ensure_connection<'p>(&mut self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
queue
.lock()
.await
.ensure_connection()
.await
.map_err(to_pyerr)?;
Ok(())
})
}

fn close<'p>(&mut self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
queue.lock().await.close().await.map_err(to_pyerr)?;
Ok(())
})
}

fn enqueue_task<'p>(
&mut self,
py: Python<'p>,
task_data: Py<PyBytes>,
) -> PyResult<Bound<'p, PyAny>> {
let bytes = task_data.as_bytes(py).to_vec();

let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
queue
.lock()
.await
.enqueue_task(bytes.into())
.await
.map_err(to_pyerr)?;
Ok(())
})
}

fn dequeue_task<'p>(&mut self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(queue
.lock()
.await
.dequeue_task()
.await
.map_err(to_pyerr)?
.map(|bytes| bytes.to_vec()))
})
}

fn get_queue_size<'p>(&mut self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
queue.lock().await.get_queue_size().await.map_err(to_pyerr)
})
}
}
146 changes: 146 additions & 0 deletions lib/runtime/src/transports/nats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,152 @@ pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
Ok((bucket.to_string(), key.to_string()))
}

/// A work queue backed by a NATS JetStream stream.
///
/// Tasks are published to the subject `<stream_name>.queue` and read back through a durable
/// pull consumer named `worker-group`. The connection is established lazily: every queue
/// operation calls [`NatsQueue::ensure_connection`] first.
pub struct NatsQueue {
    /// Name of the JetStream stream, with `/` and `\` replaced by `_`
    stream_name: String,
    /// The NATS server URL
    nats_server: String,
    /// How long a single fetch in `dequeue_task` waits for a message before returning `None`
    dequeue_timeout: time::Duration,
    /// The NATS client; `None` until `connect` succeeds and again after `close`
    client: Option<Client>,
    /// Subject pattern captured by the stream (`<stream_name>.*`)
    subject: String,
    /// Durable pull consumer used for dequeueing and for queue-size queries
    subscriber: Option<jetstream::consumer::PullConsumer>,
}

impl NatsQueue {
    /// Durable consumer name shared by every `NatsQueue` on the same stream.
    const CONSUMER_NAME: &'static str = "worker-group";
    /// Storage cap (1 GiB) applied when this type creates the stream.
    const MAX_STREAM_BYTES: i64 = 1_073_741_824;
    /// Message-count cap applied when this type creates the stream.
    const MAX_STREAM_MESSAGES: i64 = 1_000_000;

    /// Create a new, not-yet-connected `NatsQueue`.
    ///
    /// Path separators (`/`, `\`) in `stream_name` are replaced by `_`, since JetStream stream
    /// names may not contain them. No network I/O happens here.
    pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
        let sanitized_stream_name = stream_name.replace(['/', '\\'], "_");
        let subject = format!("{}.*", sanitized_stream_name);

        Self {
            stream_name: sanitized_stream_name,
            nats_server,
            dequeue_timeout,
            client: None,
            subject,
            subscriber: None,
        }
    }

    /// Connect to the NATS server and set up the stream and durable consumer.
    ///
    /// Does nothing if a client already exists. On error no state is stored, so a later call
    /// retries from scratch.
    ///
    /// # Errors
    /// Fails if the client cannot be built or connected, or if the stream or consumer cannot be
    /// looked up / created.
    pub async fn connect(&mut self) -> Result<()> {
        if self.client.is_some() {
            return Ok(());
        }

        let client_options = Client::builder().server(self.nats_server.clone()).build()?;
        let client = client_options.connect().await?;

        // The caps matter: with JetStream's default limits-based retention, acknowledged
        // messages are not deleted, so an uncapped stream grows without bound. The values
        // mirror the limits of the previous Python implementation. They only apply when this
        // call creates the stream; an existing stream is used as-is.
        let stream_config = jetstream::stream::Config {
            name: self.stream_name.clone(),
            subjects: vec![self.subject.clone()],
            max_bytes: Self::MAX_STREAM_BYTES,
            max_messages: Self::MAX_STREAM_MESSAGES,
            ..Default::default()
        };
        log::debug!("Ensuring NATS stream {} exists", self.stream_name);
        let stream = client
            .jetstream()
            .get_or_create_stream(stream_config)
            .await?;

        let consumer_config = jetstream::consumer::pull::Config {
            durable_name: Some(Self::CONSUMER_NAME.to_string()),
            ..Default::default()
        };
        let subscriber = stream.create_consumer(consumer_config).await?;

        self.subscriber = Some(subscriber);
        self.client = Some(client);
        Ok(())
    }

    /// Connect if no client handle exists yet.
    // NOTE(review): this only checks for a stored client handle; unlike the earlier Python
    // version it does not detect a connection that has since been closed.
    pub async fn ensure_connection(&mut self) -> Result<()> {
        if self.client.is_none() {
            self.connect().await?;
        }
        Ok(())
    }

    /// Drop the local client and consumer handles.
    ///
    /// Nothing is deleted on the server; a subsequent operation reconnects.
    pub async fn close(&mut self) -> Result<()> {
        self.subscriber = None;
        self.client = None;
        Ok(())
    }

    /// Publish `task_data` to `<stream_name>.queue` and wait for the server's acknowledgement.
    ///
    /// # Errors
    /// Fails if connecting fails, if the message cannot be sent, or if the server does not
    /// acknowledge that the stream stored it.
    pub async fn enqueue_task(&mut self, task_data: Bytes) -> Result<()> {
        self.ensure_connection().await?;

        let Some(client) = &self.client else {
            return Err(anyhow::anyhow!("Client not connected"));
        };
        let subject = format!("{}.queue", self.stream_name);
        // The first await only sends the message; the second awaits the PublishAck, without
        // which a rejected publish would go unnoticed.
        client
            .jetstream()
            .publish(subject, task_data)
            .await?
            .await?;
        Ok(())
    }

    /// Fetch at most one message, acknowledge it, and return its payload.
    ///
    /// Returns `Ok(None)` when no message arrives within `dequeue_timeout`. The message is
    /// acked before it is returned, so a consumer that crashes afterwards loses the task.
    pub async fn dequeue_task(&mut self) -> Result<Option<Bytes>> {
        self.ensure_connection().await?;

        let Some(subscriber) = &self.subscriber else {
            return Err(anyhow::anyhow!("Subscriber not initialized"));
        };
        let mut batch = subscriber
            .fetch()
            .expires(self.dequeue_timeout)
            .max_messages(1)
            .messages()
            .await?;

        let Some(message) = batch.next().await else {
            return Ok(None);
        };
        let message = message.map_err(|e| anyhow::anyhow!("Failed to get message: {}", e))?;
        message
            .ack()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to ack message: {}", e))?;
        Ok(Some(message.payload.clone()))
    }

    /// Number of messages the durable consumer has not yet been delivered.
    ///
    /// Queries the server through the already-held consumer instead of re-resolving the stream
    /// and consumer on every call.
    pub async fn get_queue_size(&mut self) -> Result<u64> {
        self.ensure_connection().await?;

        let Some(consumer) = self.subscriber.as_mut() else {
            return Err(anyhow::anyhow!("Subscriber not initialized"));
        };
        let info = consumer
            .info()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to get consumer info: {}", e))?;
        Ok(info.num_pending)
    }
}

#[cfg(test)]
mod tests {

Expand Down
Loading
0