Skip to content

Resolves #682 and #683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion nats/js/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#

from __future__ import annotations

from dataclasses import dataclass, fields, replace
from enum import Enum
from typing import Any, Dict, Iterable, Iterator, List, Optional, TypeVar
Expand All @@ -34,6 +33,9 @@ class Header(str, Enum):
ROLLUP = "Nats-Rollup"
STATUS = "Status"

TTL = "Nats-TTL"
MARKER_REASON = "Nats-Marker-Reason"


DEFAULT_PREFIX = "$JS.API"
INBOX_PREFIX = b"_INBOX."
Expand Down Expand Up @@ -308,9 +310,14 @@ class StreamConfig(Base):
# Metadata are user defined string key/value pairs.
metadata: Optional[Dict[str, str]] = None

# Allow per message ttl
allow_msg_ttl: bool = False
subject_delete_marker_ttl: Optional[int] = None # in seconds

@classmethod
def from_response(cls, resp: Dict[str, Any]):
cls._convert_nanoseconds(resp, "max_age")
cls._convert_nanoseconds(resp, "subject_delete_marker_ttl")
cls._convert_nanoseconds(resp, "duplicate_window")
cls._convert(resp, "placement", Placement)
cls._convert(resp, "mirror", StreamSource)
Expand All @@ -325,6 +332,9 @@ def as_dict(self) -> Dict[str, object]:
self.duplicate_window
)
result["max_age"] = self._to_nanoseconds(self.max_age)
result["subject_delete_marker_ttl"] = self._to_nanoseconds(
self.subject_delete_marker_ttl
)
if self.sources:
result["sources"] = [src.as_dict() for src in self.sources]
if self.compression and (self.compression != StoreCompression.NONE
Expand Down Expand Up @@ -453,6 +463,13 @@ class ReplayPolicy(str, Enum):
ORIGINAL = "original"


class PriorityPolicy(str, Enum):
    """Priority policy applied to a consumer's priority groups.

    The string values are the wire-level values accepted by the JetStream
    consumer configuration ``priority_policy`` field.
    """

    # Members of the group receive messages based on pending thresholds
    # (used together with the ``min_pending`` / ``min_ack_pending`` pull
    # request options).
    OVERFLOW = "overflow"
    # Delivery is pinned to a single client of the group at a time.
    PINNED_CLIENT = "pinned_client"


@dataclass
class ConsumerConfig(Base):
"""Consumer configuration.
Expand Down Expand Up @@ -500,6 +517,10 @@ class ConsumerConfig(Base):
# Metadata are user defined string key/value pairs.
metadata: Optional[Dict[str, str]] = None

# add priority groups
priority_groups: Optional[list[str]] = None
priority_policy: Optional[PriorityPolicy] = None

@classmethod
def from_response(cls, resp: Dict[str, Any]):
cls._convert_nanoseconds(resp, "ack_wait")
Expand Down
46 changes: 44 additions & 2 deletions nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ async def publish(
timeout: Optional[float] = None,
stream: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
ttl: Optional[int] = None
) -> api.PubAck:
"""
publish emits a new message to JetStream and waits for acknowledgement.
Expand All @@ -197,6 +198,9 @@ async def publish(
hdr = hdr or {}
hdr[api.Header.EXPECTED_STREAM] = stream

if ttl is not None:
hdr = hdr or {}
hdr[api.Header.TTL] = str(ttl)
try:
msg = await self._nc.request(
subject,
Expand All @@ -219,6 +223,7 @@ async def publish_async(
wait_stall: Optional[float] = None,
stream: Optional[str] = None,
headers: Optional[Dict] = None,
ttl: Optional[int] = None
) -> asyncio.Future[api.PubAck]:
"""
emits a new message to JetStream and returns a future that can be awaited for acknowledgement.
Expand All @@ -233,6 +238,10 @@ async def publish_async(
hdr = hdr or {}
hdr[api.Header.EXPECTED_STREAM] = stream

if ttl is not None:
hdr = hdr or {}
hdr[api.Header.TTL] = str(ttl)

try:
await asyncio.wait_for(
self._publish_async_pending_semaphore.acquire(),
Expand Down Expand Up @@ -1053,13 +1062,17 @@ async def fetch(
batch: int = 1,
timeout: Optional[float] = 5,
heartbeat: Optional[float] = None,
group: Optional[str] = None,
min_pending: Optional[int] = None,
min_ack_pending: Optional[int] = None
) -> List[Msg]:
"""
fetch makes a request to JetStream to be delivered a set of messages.

:param batch: Number of messages to fetch from server.
:param timeout: Max duration of the fetch request before it expires.
:param heartbeat: Idle Heartbeat interval in seconds for the fetch request.
:param group: If consumer has configured PriorityGroups, every Pull Request needs to provide it.

::

Expand Down Expand Up @@ -1095,16 +1108,25 @@ async def main():
timeout * 1_000_000_000
) - 100_000 if timeout else None
if batch == 1:
msg = await self._fetch_one(expires, timeout, heartbeat)
msg = await self._fetch_one(
expires, timeout, heartbeat, group, min_pending,
min_ack_pending
)
return [msg]
msgs = await self._fetch_n(batch, expires, timeout, heartbeat)
msgs = await self._fetch_n(
batch, expires, timeout, heartbeat, group, min_pending,
min_ack_pending
)
return msgs

async def _fetch_one(
self,
expires: Optional[int],
timeout: Optional[float],
heartbeat: Optional[float] = None,
group: Optional[str] = None,
min_pending: Optional[int] = None,
min_ack_pending: Optional[int] = None
) -> Msg:
queue = self._sub._pending_queue

Expand Down Expand Up @@ -1133,6 +1155,15 @@ async def _fetch_one(
heartbeat * 1_000_000_000
) # to nanoseconds

if group:
next_req["group"] = group

if min_pending:
next_req["min_pending"] = min_pending

if min_ack_pending:
next_req["min_ack_pending"] = min_ack_pending

await self._nc.publish(
self._nms,
json.dumps(next_req).encode(),
Expand Down Expand Up @@ -1183,6 +1214,9 @@ async def _fetch_n(
expires: Optional[int],
timeout: Optional[float],
heartbeat: Optional[float] = None,
group: Optional[str] = None,
min_pending: Optional[int] = None,
min_ack_pending: Optional[int] = None
) -> List[Msg]:
msgs = []
queue = self._sub._pending_queue
Expand Down Expand Up @@ -1217,6 +1251,14 @@ async def _fetch_n(
next_req["idle_heartbeat"] = int(
heartbeat * 1_000_000_000
) # to nanoseconds
if group:
next_req["group"] = str(group)

if min_pending:
next_req["min_pending"] = min_pending

if min_ack_pending:
next_req["min_ack_pending"] = min_ack_pending
next_req["no_wait"] = True
await self._nc.publish(
self._nms,
Expand Down
155 changes: 154 additions & 1 deletion tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from hashlib import sha256

import nats
import nats.js.api
import pytest
from nats.aio.client import Client as NATS
from nats.aio.errors import *
from nats.js.api import Header, PriorityPolicy
Expand Down Expand Up @@ -132,6 +132,64 @@ async def test_publish_async(self):

await nc.close()

@async_test
async def test_publish_per_message_ttl(self):
    """A per-message TTL passed to publish is stored as the Nats-TTL header."""
    client = NATS()
    await client.connect()
    js = client.jetstream()

    await js.add_stream(
        name="QUUX",
        subjects=["ququ"],
        allow_msg_ttl=True,
        subject_delete_marker_ttl=2,
    )
    pub_ack = await js.publish(
        subject="ququ", payload=b"bar:1", stream="QUUX", ttl=2
    )
    assert pub_ack.stream == "QUUX"
    assert pub_ack.seq == 1

    stream_info = await js.stream_info(name="QUUX")
    assert stream_info.state.messages == 1

    # The TTL must round-trip: read the stored message back and check
    # the header the server persisted.
    stored = await js.get_last_msg(stream_name="QUUX", subject="ququ")
    assert stored.headers.get(Header.TTL) == '2'
    await client.close()

@async_test
async def test_async_publish_per_message_ttl(self):
    """publish_async forwards the per-message TTL as the Nats-TTL header."""
    client = NATS()
    await client.connect()
    js = client.jetstream()

    await js.add_stream(
        name="QUUX",
        subjects=["ququ"],
        allow_msg_ttl=True,
        subject_delete_marker_ttl=2,
    )

    future = await js.publish_async(subject="ququ", payload=b"bar:1", ttl=2)
    await js.publish_async_completed()
    pub_ack = (await asyncio.gather(future))[0]

    assert pub_ack.stream == "QUUX"
    assert pub_ack.seq == 1

    stream_info = await js.stream_info(name="QUUX")
    assert stream_info.state.messages == 1

    # The TTL must round-trip: read the stored message back and check
    # the header the server persisted.
    stored = await js.get_last_msg(stream_name="QUUX", subject="ququ")
    assert stored.headers.get(Header.TTL) == '2'
    await client.close()


class PullSubscribeTest(SingleJetStreamServerTestCase):

Expand Down Expand Up @@ -1055,6 +1113,101 @@ async def test_fetch_heartbeats(self):

await nc.close()

@async_test
async def test_pull_overflow(self):
    """Fetch from a consumer configured with an overflow priority group.

    Fix: the test referenced ``api.PriorityPolicy`` but the file imports
    ``from nats.js.api import ...`` — the bare ``api`` name is undefined
    and raised NameError; use the directly imported ``PriorityPolicy``.
    """
    nc = NATS()
    await nc.connect()
    js = nc.jetstream()
    await js.add_stream(name="events", subjects=["events.a"])
    await js.add_consumer(
        "events",
        durable_name="a",
        ack_policy="explicit",
        max_waiting=512,
        max_ack_pending=1024,
        filter_subject="events.a",
        priority_policy=PriorityPolicy.OVERFLOW.value,
        priority_groups=["A"],
    )
    sub = await js.pull_subscribe_bind(
        "a",
        stream="events",
    )
    await js.publish("events.a", b"test")

    # Pull requests against a prioritized consumer must name their group.
    msgs = await sub.fetch(1, group="A")
    for msg in msgs:
        await msg.ack()
    await nc.close()

@async_test
async def test_pull_overflow_min_pending(self):
    """``min_pending`` gates delivery until enough messages are pending.

    Fix: replaced the undefined ``api.PriorityPolicy`` reference (the file
    imports names from ``nats.js.api`` directly) and dropped the dead
    assignment inside the ``pytest.raises`` block — the awaited call raises,
    so the name was never bound.
    """
    nc = NATS()
    await nc.connect()
    js = nc.jetstream()
    await js.add_stream(name="events", subjects=["events.a"])
    await js.add_consumer(
        "events",
        durable_name="a",
        ack_policy="explicit",
        max_waiting=512,
        max_ack_pending=1024,
        filter_subject="events.a",
        priority_policy=PriorityPolicy.OVERFLOW.value,
        priority_groups=["A"],
    )
    sub = await js.pull_subscribe_bind(
        "a",
        stream="events",
    )
    for i in range(0, 5):
        await js.publish("events.a", b"i:%d" % i)

    # Only 5 messages are pending; a request demanding at least 10
    # pending messages is cancelled instead of delivering.
    with pytest.raises(asyncio.exceptions.CancelledError):
        await sub.fetch(1, group="A", min_pending=10)

    for i in range(0, 20):
        await js.publish("events.a", b"i:%d" % i)

    # Now num_pending >= min_pending, so the fetch succeeds.
    msgs = await sub.fetch(1, group="A", min_pending=10)
    for msg in msgs:
        await msg.ack()
    await nc.close()

@async_test
async def test_pull_overflow_min_ack_pending(self):
    """``min_ack_pending`` gates delivery until enough messages await ack.

    Fix: replaced the undefined ``api.PriorityPolicy`` reference (the file
    imports names from ``nats.js.api`` directly), which raised NameError.
    """
    nc = NATS()
    await nc.connect()
    js = nc.jetstream()
    await js.add_stream(name="events", subjects=["events.a"])
    await js.add_consumer(
        "events",
        durable_name="a",
        ack_policy="explicit",
        max_waiting=512,
        max_ack_pending=1024,
        filter_subject="events.a",
        priority_policy=PriorityPolicy.OVERFLOW.value,
        priority_groups=["A"],
    )
    sub = await js.pull_subscribe_bind(
        "a",
        stream="events",
    )
    for i in range(0, 5):
        await js.publish("events.a", b"i:%d" % i)

    # Nothing has been delivered yet, so ack-pending is below the
    # requested threshold and the request is cancelled.
    with pytest.raises(asyncio.exceptions.CancelledError):
        await sub.fetch(1, group="A", min_ack_pending=10)

    for i in range(0, 20):
        await js.publish("events.a", b"i:%d" % i)
    # Deliver (without acking) enough messages to raise ack-pending,
    # then the thresholded fetch succeeds.
    await sub.fetch(10, group="A")
    await sub.fetch(1, group="A", min_ack_pending=10)
    await nc.close()


class JSMTest(SingleJetStreamServerTestCase):

Expand Down