diff --git a/nats/js/api.py b/nats/js/api.py index e9db83e1..be0cd75c 100644 --- a/nats/js/api.py +++ b/nats/js/api.py @@ -13,7 +13,6 @@ # from __future__ import annotations - from dataclasses import dataclass, fields, replace from enum import Enum from typing import Any, Dict, Iterable, Iterator, List, Optional, TypeVar @@ -34,6 +33,9 @@ class Header(str, Enum): ROLLUP = "Nats-Rollup" STATUS = "Status" + TTL = "Nats-TTL" + MARKER_REASON = "Nats-Marker-Reason" + DEFAULT_PREFIX = "$JS.API" INBOX_PREFIX = b"_INBOX." @@ -308,9 +310,14 @@ class StreamConfig(Base): # Metadata are user defined string key/value pairs. metadata: Optional[Dict[str, str]] = None + # Allow per message ttl + allow_msg_ttl: bool = False + subject_delete_marker_ttl: Optional[int] = None # in seconds + @classmethod def from_response(cls, resp: Dict[str, Any]): cls._convert_nanoseconds(resp, "max_age") + cls._convert_nanoseconds(resp, "subject_delete_marker_ttl") cls._convert_nanoseconds(resp, "duplicate_window") cls._convert(resp, "placement", Placement) cls._convert(resp, "mirror", StreamSource) @@ -325,6 +332,9 @@ def as_dict(self) -> Dict[str, object]: self.duplicate_window ) result["max_age"] = self._to_nanoseconds(self.max_age) + result["subject_delete_marker_ttl"] = self._to_nanoseconds( + self.subject_delete_marker_ttl + ) if self.sources: result["sources"] = [src.as_dict() for src in self.sources] if self.compression and (self.compression != StoreCompression.NONE @@ -453,6 +463,13 @@ class ReplayPolicy(str, Enum): ORIGINAL = "original" +class PriorityPolicy(str, Enum): + """Group priority policy""" + + OVERFLOW = "overflow" + PINNED_CLIENT = "pinned_client" + + @dataclass class ConsumerConfig(Base): """Consumer configuration. @@ -500,6 +517,10 @@ class ConsumerConfig(Base): # Metadata are user defined string key/value pairs. metadata: Optional[Dict[str, str]] = None + # add priority groups + priority_groups: Optional[list[str]] = None + priority_policy: Optional[PriorityPolicy] = None + @classmethod def from_response(cls, resp: Dict[str, Any]): cls._convert_nanoseconds(resp, "ack_wait") diff --git a/nats/js/client.py b/nats/js/client.py index d26413c0..53a4d4ee 100644 --- a/nats/js/client.py +++ b/nats/js/client.py @@ -186,6 +186,7 @@ async def publish( timeout: Optional[float] = None, stream: Optional[str] = None, headers: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None ) -> api.PubAck: """ publish emits a new message to JetStream and waits for acknowledgement. @@ -197,6 +198,9 @@ async def publish( hdr = hdr or {} hdr[api.Header.EXPECTED_STREAM] = stream + if ttl is not None: + hdr = hdr or {} + hdr[api.Header.TTL] = str(ttl) try: msg = await self._nc.request( subject, @@ -219,6 +223,7 @@ async def publish_async( wait_stall: Optional[float] = None, stream: Optional[str] = None, headers: Optional[Dict] = None, + ttl: Optional[int] = None ) -> asyncio.Future[api.PubAck]: """ emits a new message to JetStream and returns a future that can be awaited for acknowledgement. @@ -233,6 +238,10 @@ async def publish_async( hdr = hdr or {} hdr[api.Header.EXPECTED_STREAM] = stream + if ttl is not None: + hdr = hdr or {} + hdr[api.Header.TTL] = str(ttl) + try: await asyncio.wait_for( self._publish_async_pending_semaphore.acquire(), @@ -1053,6 +1062,9 @@ async def fetch( batch: int = 1, timeout: Optional[float] = 5, heartbeat: Optional[float] = None, + group: Optional[str] = None, + min_pending: Optional[int] = None, + min_ack_pending: Optional[int] = None ) -> List[Msg]: """ fetch makes a request to JetStream to be delivered a set of messages. @@ -1060,6 +1072,7 @@ async def fetch( :param batch: Number of messages to fetch from server. :param timeout: Max duration of the fetch request before it expires. :param heartbeat: Idle Heartbeat interval in seconds for the fetch request. + :param group: If consumer has configured PriorityGroups, every Pull Request needs to provide it. :: @@ -1095,9 +1108,15 @@ async def main(): timeout * 1_000_000_000 ) - 100_000 if timeout else None if batch == 1: - msg = await self._fetch_one(expires, timeout, heartbeat) + msg = await self._fetch_one( + expires, timeout, heartbeat, group, min_pending, + min_ack_pending + ) return [msg] - msgs = await self._fetch_n(batch, expires, timeout, heartbeat) + msgs = await self._fetch_n( + batch, expires, timeout, heartbeat, group, min_pending, + min_ack_pending + ) return msgs async def _fetch_one( @@ -1105,6 +1124,9 @@ async def _fetch_one( expires: Optional[int], timeout: Optional[float], heartbeat: Optional[float] = None, + group: Optional[str] = None, + min_pending: Optional[int] = None, + min_ack_pending: Optional[int] = None ) -> Msg: queue = self._sub._pending_queue @@ -1133,6 +1155,15 @@ async def _fetch_one( heartbeat * 1_000_000_000 ) # to nanoseconds + if group: + next_req["group"] = group + + if min_pending: + next_req["min_pending"] = min_pending + + if min_ack_pending: + next_req["min_ack_pending"] = min_ack_pending + await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ -1183,6 +1214,9 @@ async def _fetch_n( expires: Optional[int], timeout: Optional[float], heartbeat: Optional[float] = None, + group: Optional[str] = None, + min_pending: Optional[int] = None, + min_ack_pending: Optional[int] = None ) -> List[Msg]: msgs = [] queue = self._sub._pending_queue @@ -1217,6 +1251,14 @@ async def _fetch_n( next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds + if group: + next_req["group"] = str(group) + + if min_pending: + next_req["min_pending"] = min_pending + + if min_ack_pending: + next_req["min_ack_pending"] = min_ack_pending next_req["no_wait"] = True await self._nc.publish( self._nms, diff --git a/tests/test_js.py b/tests/test_js.py index 08a6377b..2d504254 100644 --- a/tests/test_js.py +++ b/tests/test_js.py @@ -13,7 +13,7 @@ from hashlib import sha256 import nats -import nats.js.api +from nats.js.api import Header import pytest from nats.aio.client import Client as NATS from nats.aio.errors import * @@ -132,6 +132,64 @@ async def test_publish_async(self): await nc.close() + @async_test + async def test_publish_per_message_ttl(self): + nc = NATS() + await nc.connect() + js = nc.jetstream() + + await js.add_stream( + name="QUUX", + subjects=["ququ"], + allow_msg_ttl=True, + subject_delete_marker_ttl=2, + ) + ack = await js.publish( + subject="ququ", payload=b"bar:1", stream="QUUX", ttl=2 + ) + assert ack.stream == "QUUX" + assert ack.seq == 1 + + info = await js.stream_info(name="QUUX") + assert info.state.messages == 1 + + message = await js.get_last_msg(stream_name="QUUX", subject="ququ") + ttl = message.headers.get(Header.TTL) + assert ttl == '2' + await nc.close() + + @async_test + async def test_async_publish_per_message_ttl(self): + nc = NATS() + await nc.connect() + js = nc.jetstream() + + await js.add_stream( + name="QUUX", + subjects=["ququ"], + allow_msg_ttl=True, + subject_delete_marker_ttl=2, + ) + + futures = [ + await js.publish_async(subject="ququ", payload=b"bar:1", ttl=2) + ] + + await js.publish_async_completed() + results = await asyncio.gather(*futures) + ack = results[0] + + assert ack.stream == "QUUX" + assert ack.seq == 1 + + info = await js.stream_info(name="QUUX") + assert info.state.messages == 1 + + message = await js.get_last_msg(stream_name="QUUX", subject="ququ") + ttl = message.headers.get(Header.TTL) + assert ttl == '2' + await nc.close() + class PullSubscribeTest(SingleJetStreamServerTestCase): @@ -1055,6 +1113,101 @@ async def test_fetch_heartbeats(self): await nc.close() + @async_test + async def test_pull_overflow(self): + nc = NATS() + await nc.connect() + js = nc.jetstream() + await js.add_stream(name="events", subjects=["events.a"]) + await js.add_consumer( + "events", + durable_name="a", + ack_policy="explicit", + max_waiting=512, + max_ack_pending=1024, + filter_subject="events.a", + priority_policy=api.PriorityPolicy.OVERFLOW.value, + priority_groups=["A"] + ) + sub = await js.pull_subscribe_bind( + "a", + stream="events", + ) + await js.publish("events.a", b"test") + + msgs = await sub.fetch(1, group="A") + for msg in msgs: + await msg.ack() + await nc.close() + + @async_test + async def test_pull_overflow_min_pending(self): + nc = NATS() + await nc.connect() + js = nc.jetstream() + await js.add_stream(name="events", subjects=["events.a"]) + await js.add_consumer( + "events", + durable_name="a", + ack_policy="explicit", + max_waiting=512, + max_ack_pending=1024, + filter_subject="events.a", + priority_policy=api.PriorityPolicy.OVERFLOW.value, + priority_groups=["A"] + ) + sub = await js.pull_subscribe_bind( + "a", + stream="events", + ) + for i in range(0, 5): + await js.publish("events.a", b"i:%d" % i) + + # because min pending > num_pending + with pytest.raises(asyncio.exceptions.CancelledError): + msgs = await sub.fetch(1, group="A", min_pending=10) + + for i in range(0, 20): + await js.publish("events.a", b"i:%d" % i) + + msgs = await sub.fetch(1, group="A", min_pending=10) + for msg in msgs: + await msg.ack() + await nc.close() + + @async_test + async def test_pull_overflow_min_ack_pending(self): + nc = NATS() + await nc.connect() + js = nc.jetstream() + await js.add_stream(name="events", subjects=["events.a"]) + await js.add_consumer( + "events", + durable_name="a", + ack_policy="explicit", + max_waiting=512, + max_ack_pending=1024, + filter_subject="events.a", + priority_policy=api.PriorityPolicy.OVERFLOW.value, + priority_groups=["A"] + ) + sub = await js.pull_subscribe_bind( + "a", + stream="events", + ) + for i in range(0, 5): + await js.publish("events.a", b"i:%d" % i) + + # because min_ack_pending > num_ack_pending + with pytest.raises(asyncio.exceptions.CancelledError): + await sub.fetch(1, group="A", min_ack_pending=10) + + for i in range(0, 20): + await js.publish("events.a", b"i:%d" % i) + await sub.fetch(10, group="A") + await sub.fetch(1, group="A", min_ack_pending=10) + await nc.close() + class JSMTest(SingleJetStreamServerTestCase):