Skip to content

Commit e95bec8

Browse files
author
Nikita Kharlov
committed
max reconnect attempts
1 parent 20fcfc9 commit e95bec8

File tree

4 files changed

+63
-7
lines changed

4 files changed

+63
-7
lines changed

aio_pika/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ async def main():
294294
query=kw
295295
)
296296

297-
connection = connection_class(url, loop=loop)
297+
connection = connection_class(url, loop=loop, **kwargs)
298298
await connection.connect(timeout=timeout)
299299
return connection
300300

aio_pika/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
3939
pass
4040

4141

42+
class MaxReconnectAttemptsReached(Exception):
43+
pass
44+
45+
4246
__all__ = (
4347
'AMQPChannelError',
4448
'AMQPConnectionError',
@@ -51,6 +55,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
5155
'DuplicateConsumerTag',
5256
'IncompatibleProtocolError',
5357
'InvalidFrameError',
58+
'MaxReconnectAttemptsReached',
5459
'MessageProcessError',
5560
'MethodNotImplemented',
5661
'ProbableAuthenticationError',

aio_pika/robust_connection.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Callable, Type
55

66
from aiormq.connection import parse_bool, parse_int
7-
from .exceptions import CONNECTION_EXCEPTIONS
7+
from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached
88
from .connection import Connection, connect, ConnectionType
99
from .tools import CallbackCollection
1010
from .types import TimeoutType
@@ -29,6 +29,7 @@ class RobustConnection(Connection):
2929

3030
CHANNEL_CLASS = RobustChannel
3131
KWARGS_TYPES = (
32+
('max_reconnect_attempts', parse_int, '0'),
3233
('reconnect_interval', parse_int, '5'),
3334
('fail_fast', parse_bool, '1'),
3435
)
@@ -42,7 +43,9 @@ def __init__(self, url, loop=None, **kwargs):
4243
self.fail_fast = self.kwargs['fail_fast']
4344

4445
self.__channels = set()
46+
self._reconnect_attempt = None
4547
self._on_reconnect_callbacks = CallbackCollection()
48+
self._on_stop_callbacks = CallbackCollection()
4649
self._closed = False
4750

4851
@property
@@ -76,6 +79,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]):
7679

7780
self._on_reconnect_callbacks.add(callback)
7881

82+
def add_stop_callback(self, callback: Callable[[Exception], None]):
83+
self._on_stop_callbacks.add(callback)
84+
7985
async def connect(self, timeout: TimeoutType = None):
8086
while True:
8187
try:
@@ -97,6 +103,16 @@ async def reconnect(self):
97103
if self.is_closed:
98104
return
99105

106+
if self.kwargs['max_reconnect_attempts'] > 0:
107+
if self._reconnect_attempt is None:
108+
self._reconnect_attempt = 1
109+
else:
110+
self._reconnect_attempt += 1
111+
112+
if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']:
113+
await self.close(MaxReconnectAttemptsReached())
114+
return
115+
100116
try:
101117
await super().connect()
102118
except CONNECTION_EXCEPTIONS:
@@ -124,6 +140,7 @@ def channel(self, channel_number: int = None,
124140
return channel
125141

126142
async def _on_reconnect(self):
143+
self._reconnect_attempt = None
127144
for number, channel in self._channels.items():
128145
try:
129146
await channel.on_reconnect(self, number)
@@ -144,6 +161,7 @@ async def close(self, exc=asyncio.CancelledError):
144161
return
145162

146163
self._closed = True
164+
self._on_stop_callbacks(exc)
147165

148166
if self.connection is None:
149167
return

tests/test_amqp_robust.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiormq import ChannelLockedResource
99

1010
from aio_pika import connect_robust, Message
11+
from aio_pika.exceptions import MaxReconnectAttemptsReached
1112
from aio_pika.robust_channel import RobustChannel
1213
from aio_pika.robust_connection import RobustConnection
1314
from aio_pika.robust_queue import RobustQueue
@@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport,
2728
self.src_port = sport
2829
self.dst_host = dhost
2930
self.dst_port = dport
31+
self._run_task = None
3032
self.connections = set()
3133

3234
async def _pipe(self, reader: asyncio.StreamReader,
@@ -54,14 +56,18 @@ async def handle_client(self, creader: asyncio.StreamReader,
5456
])
5557

5658
async def start(self):
57-
result = await asyncio.start_server(
59+
self._run_task = await asyncio.start_server(
5860
self.handle_client,
5961
host=self.src_host,
6062
port=self.src_port,
6163
loop=self.loop,
6264
)
6365

64-
return result
66+
async def stop(self):
67+
assert self._run_task is not None
68+
self._run_task.close()
69+
await self.disconnect()
70+
self._run_task = None
6571

6672
async def disconnect(self):
6773
tasks = list()
@@ -74,7 +80,8 @@ async def close(writer):
7480
writer = self.connections.pop() # type: asyncio.StreamWriter
7581
tasks.append(self.loop.create_task(close(writer)))
7682

77-
await asyncio.wait(tasks)
83+
if tasks:
84+
await asyncio.wait(tasks)
7885

7986

8087
class TestCase(AMQPTestCase):
@@ -86,7 +93,7 @@ def get_unused_port() -> int:
8693
sock.close()
8794
return port
8895

89-
async def create_connection(self, cleanup=True):
96+
async def create_connection(self, cleanup=True, max_reconnect_attempts=0):
9097
self.proxy = Proxy(
9198
dhost=AMQP_URL.host,
9299
dport=AMQP_URL.port,
@@ -100,7 +107,11 @@ async def create_connection(self, cleanup=True):
100107
self.proxy.src_host
101108
).with_port(
102109
self.proxy.src_port
103-
).update_query(reconnect_interval=1)
110+
).update_query(
111+
reconnect_interval=1
112+
).update_query(
113+
max_reconnect_attempts=max_reconnect_attempts
114+
)
104115

105116
client = await connect_robust(str(url), loop=self.loop)
106117

@@ -212,6 +223,28 @@ async def reader():
212223

213224
assert len(shared) == 10
214225

226+
async def test_robust_reconnect_max_attempts(self):
227+
client = await self.create_connection(max_reconnect_attempts=2)
228+
self.assertIsInstance(client, RobustConnection)
229+
230+
first_close = asyncio.Future()
231+
stopped = asyncio.Future()
232+
233+
def stop_callback(exc):
234+
assert isinstance(exc, MaxReconnectAttemptsReached)
235+
stopped.set_result(True)
236+
237+
def close_callback(f):
238+
first_close.set_result(True)
239+
240+
client.add_stop_callback(stop_callback)
241+
client.connection.closing.add_done_callback(close_callback)
242+
await self.proxy.stop()
243+
await first_close
244+
# 1 interval before first try and 2 after attempts
245+
await asyncio.wait_for(stopped,
246+
timeout=client.reconnect_interval * 3 + 0.1)
247+
215248
async def test_channel_locked_resource2(self):
216249
ch1 = await self.create_channel()
217250
ch2 = await self.create_channel()

0 commit comments

Comments
 (0)