Skip to content

Commit 71c5a59

Browse files
author
Nikita Kharlov
committed
max reconnect attempts
1 parent 20fcfc9 commit 71c5a59

File tree

4 files changed

+72
-7
lines changed

4 files changed

+72
-7
lines changed

aio_pika/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ async def main():
294294
query=kw
295295
)
296296

297-
connection = connection_class(url, loop=loop)
297+
connection = connection_class(url, loop=loop, **kwargs)
298298
await connection.connect(timeout=timeout)
299299
return connection
300300

aio_pika/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
3939
pass
4040

4141

42+
class MaxReconnectAttemptsReached(Exception):
43+
pass
44+
45+
4246
__all__ = (
4347
'AMQPChannelError',
4448
'AMQPConnectionError',
@@ -51,6 +55,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
5155
'DuplicateConsumerTag',
5256
'IncompatibleProtocolError',
5357
'InvalidFrameError',
58+
'MaxReconnectAttemptsReached',
5459
'MessageProcessError',
5560
'MethodNotImplemented',
5661
'ProbableAuthenticationError',

aio_pika/robust_connection.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Callable, Type
55

66
from aiormq.connection import parse_bool, parse_int
7-
from .exceptions import CONNECTION_EXCEPTIONS
7+
from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached
88
from .connection import Connection, connect, ConnectionType
99
from .tools import CallbackCollection
1010
from .types import TimeoutType
@@ -29,6 +29,7 @@ class RobustConnection(Connection):
2929

3030
CHANNEL_CLASS = RobustChannel
3131
KWARGS_TYPES = (
32+
('max_reconnect_attempts', parse_int, '0'),
3233
('reconnect_interval', parse_int, '5'),
3334
('fail_fast', parse_bool, '1'),
3435
)
@@ -41,8 +42,13 @@ def __init__(self, url, loop=None, **kwargs):
4142
self.reconnect_interval = self.kwargs['reconnect_interval']
4243
self.fail_fast = self.kwargs['fail_fast']
4344

45+
self._stop_future = self.loop.create_future()
46+
self._stop_future.add_done_callback(self._on_stop)
47+
4448
self.__channels = set()
49+
self._reconnect_attempt = None
4550
self._on_reconnect_callbacks = CallbackCollection()
51+
self._on_stop_callbacks = CallbackCollection()
4652
self._closed = False
4753

4854
@property
@@ -63,11 +69,18 @@ def _on_connection_close(self, connection, closing, *args, **kwargs):
6369

6470
super()._on_connection_close(connection, closing)
6571

72+
if isinstance(closing.exception(), MaxReconnectAttemptsReached):
73+
return
74+
6675
self.loop.call_later(
6776
self.reconnect_interval,
6877
lambda: self.loop.create_task(self.reconnect())
6978
)
7079

80+
def _on_stop(self, future):
81+
for cb in self._on_stop_callbacks:
82+
cb(future.exception())
83+
7184
def add_reconnect_callback(self, callback: Callable[[], None]):
7285
""" Add callback which will be called after reconnect.
7386
@@ -76,6 +89,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]):
7689

7790
self._on_reconnect_callbacks.add(callback)
7891

92+
def add_stop_callback(self, callback: Callable[[Exception], None]):
93+
self._on_stop_callbacks.add(callback)
94+
7995
async def connect(self, timeout: TimeoutType = None):
8096
while True:
8197
try:
@@ -97,6 +113,16 @@ async def reconnect(self):
97113
if self.is_closed:
98114
return
99115

116+
if self.kwargs['max_reconnect_attempts'] > 0:
117+
if self._reconnect_attempt is None:
118+
self._reconnect_attempt = 1
119+
else:
120+
self._reconnect_attempt += 1
121+
122+
if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']:
123+
self._stop_future.set_exception(MaxReconnectAttemptsReached())
124+
return
125+
100126
try:
101127
await super().connect()
102128
except CONNECTION_EXCEPTIONS:
@@ -124,6 +150,7 @@ def channel(self, channel_number: int = None,
124150
return channel
125151

126152
async def _on_reconnect(self):
153+
self._reconnect_attempt = None
127154
for number, channel in self._channels.items():
128155
try:
129156
await channel.on_reconnect(self, number)

tests/test_amqp_robust.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiormq import ChannelLockedResource
99

1010
from aio_pika import connect_robust, Message
11+
from aio_pika.exceptions import MaxReconnectAttemptsReached
1112
from aio_pika.robust_channel import RobustChannel
1213
from aio_pika.robust_connection import RobustConnection
1314
from aio_pika.robust_queue import RobustQueue
@@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport,
2728
self.src_port = sport
2829
self.dst_host = dhost
2930
self.dst_port = dport
31+
self._run_task = None
3032
self.connections = set()
3133

3234
async def _pipe(self, reader: asyncio.StreamReader,
@@ -54,14 +56,18 @@ async def handle_client(self, creader: asyncio.StreamReader,
5456
])
5557

5658
async def start(self):
57-
result = await asyncio.start_server(
59+
self._run_task = await asyncio.start_server(
5860
self.handle_client,
5961
host=self.src_host,
6062
port=self.src_port,
6163
loop=self.loop,
6264
)
6365

64-
return result
66+
async def stop(self):
67+
assert self._run_task is not None
68+
self._run_task.close()
69+
await self.disconnect()
70+
self._run_task = None
6571

6672
async def disconnect(self):
6773
tasks = list()
@@ -74,7 +80,8 @@ async def close(writer):
7480
writer = self.connections.pop() # type: asyncio.StreamWriter
7581
tasks.append(self.loop.create_task(close(writer)))
7682

77-
await asyncio.wait(tasks)
83+
if tasks:
84+
await asyncio.wait(tasks)
7885

7986

8087
class TestCase(AMQPTestCase):
@@ -86,7 +93,7 @@ def get_unused_port() -> int:
8693
sock.close()
8794
return port
8895

89-
async def create_connection(self, cleanup=True):
96+
async def create_connection(self, cleanup=True, max_reconnect_attempts=0):
9097
self.proxy = Proxy(
9198
dhost=AMQP_URL.host,
9299
dport=AMQP_URL.port,
@@ -100,7 +107,11 @@ async def create_connection(self, cleanup=True):
100107
self.proxy.src_host
101108
).with_port(
102109
self.proxy.src_port
103-
).update_query(reconnect_interval=1)
110+
).update_query(
111+
reconnect_interval=1
112+
).update_query(
113+
max_reconnect_attempts=max_reconnect_attempts
114+
)
104115

105116
client = await connect_robust(str(url), loop=self.loop)
106117

@@ -212,6 +223,28 @@ async def reader():
212223

213224
assert len(shared) == 10
214225

226+
async def test_robust_reconnect_max_attempts(self):
227+
client = await self.create_connection(max_reconnect_attempts=2)
228+
self.assertIsInstance(client, RobustConnection)
229+
230+
first_close = asyncio.Future()
231+
stopped = asyncio.Future()
232+
233+
def stop_callback(exc):
234+
assert isinstance(exc, MaxReconnectAttemptsReached)
235+
stopped.set_result(True)
236+
237+
def close_callback(f):
238+
first_close.set_result(True)
239+
240+
client.add_stop_callback(stop_callback)
241+
client.connection.closing.add_done_callback(close_callback)
242+
await self.proxy.stop()
243+
await first_close
244+
# 1 interval before first try and 2 after attempts
245+
await asyncio.wait_for(stopped,
246+
timeout=client.reconnect_interval * 3 + 0.1)
247+
215248
async def test_channel_locked_resource2(self):
216249
ch1 = await self.create_channel()
217250
ch2 = await self.create_channel()

0 commit comments

Comments
 (0)