Skip to content

Commit fee1abe

Browse files
committed
add robust_connection_rrhost.py
1 parent c281314 commit fee1abe

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from typing import List, Optional, Any
2+
from yarl import URL
3+
from aio_pika.robust_connection import RobustConnection
4+
from aio_pika.connection import make_url
5+
from urllib.parse import urlparse
6+
from .log import get_logger
7+
8+
log = get_logger(__name__)
9+
10+
11+
class RobustConnectionRRHost:
12+
"""
13+
Robust AMQP connection with round-robin host selection.
14+
15+
This class manages a single RobustConnection instance internally,
16+
cycling through provided URLs until a successful connection is made.
17+
"""
18+
19+
def __init__(self, urls: List[str], default_port: int = 5672,
20+
**kwargs: Any):
21+
"""
22+
Initialize with a list of broker URLs, normalizing and applying default port if missing.
23+
24+
:param urls: List of AMQP broker URLs as strings.
25+
:param default_port: Default port used if not specified in URLs.
26+
:param kwargs: Additional arguments passed to RobustConnection.
27+
"""
28+
self.urls = []
29+
for url in urls:
30+
parsed = urlparse(url)
31+
if not parsed.scheme:
32+
url = f"amqp://{url}"
33+
url_obj = make_url(url)
34+
if not url_obj.host:
35+
raise ValueError(f"Host missing in URL {url_obj}")
36+
if url_obj.port is None:
37+
url_obj = URL.build(
38+
scheme=url_obj.scheme,
39+
user=url_obj.user,
40+
password=url_obj.password,
41+
host=url_obj.host,
42+
port=default_port,
43+
path=url_obj.path,
44+
query=url_obj.query,
45+
fragment=url_obj.fragment,
46+
)
47+
self.urls.append(url_obj)
48+
self._current_index = 0
49+
self._kwargs = kwargs
50+
self._connection: Optional[
51+
RobustConnection
52+
] = None # Active connection instance, None if disconnected
53+
self._connect_timeout = None # Timeout used for connection attempts
54+
55+
async def connect(self, timeout: Optional[float] = None) -> None:
56+
"""
57+
Attempt to connect to one of the provided URLs in round-robin order.
58+
59+
:param timeout: Optional connection timeout in seconds.
60+
:raises Exception: Raises the last exception if all connection attempts fail.
61+
"""
62+
self._connect_timeout = timeout
63+
last_exc = None
64+
for _ in range(len(self.urls)):
65+
url = str(self.urls[self._current_index])
66+
try:
67+
self._connection = RobustConnection(url, **self._kwargs)
68+
await self._connection.connect(timeout=timeout)
69+
return
70+
except Exception as e:
71+
last_exc = e
72+
self._current_index = (self._current_index + 1) % len(
73+
self.urls)
74+
raise last_exc or RuntimeError("All connection attempts failed")
75+
76+
async def _on_connection_close(self, closing) -> None:
77+
"""
78+
Internal callback triggered on connection close to attempt reconnection.
79+
"""
80+
if self._connection and not self._connection.is_closed:
81+
await self.reconnect()
82+
if self._connection:
83+
await self._connection._on_connection_close(closing)
84+
85+
async def reconnect(self) -> None:
86+
"""
87+
Perform reconnection to the next URL in round-robin order.
88+
"""
89+
self._current_index = (self._current_index + 1) % len(self.urls)
90+
try:
91+
await self.connect(timeout=self._connect_timeout)
92+
if self._connection:
93+
await self._connection.reconnect_callbacks()
94+
except Exception as e:
95+
log.info(
96+
f"Reconnect failed on {self.urls[self._current_index]}: {e}")
97+
98+
def __getattr__(self, name: str) -> Any:
99+
if self._connection:
100+
return getattr(self._connection, name)
101+
raise AttributeError(
102+
f"'RobustConnectionRRHost' object has no attribute '{name}'")
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pytest
2+
from yarl import URL
3+
from aio_pika.robust_connection_rrhost import RobustConnectionRRHost
4+
5+
@pytest.mark.asyncio
6+
async def test_connect_with_rabbitmq_container(amqp_url):
7+
urls = [str(amqp_url)]
8+
conn = RobustConnectionRRHost(urls)
9+
await conn.connect(timeout=5)
10+
assert not conn.is_closed
11+
await conn.close()
12+
13+
@pytest.mark.asyncio
14+
async def test_failover_with_rabbitmq_container(amqp_url):
15+
invalid_url = "amqp://guest:guest@invalidhost:5672/"
16+
urls = [invalid_url, str(amqp_url)]
17+
conn = RobustConnectionRRHost(urls)
18+
await conn.connect(timeout=5)
19+
assert not conn.is_closed
20+
await conn.close()
21+
22+
@pytest.mark.asyncio
23+
async def test_amqp_scheme_with_rabbitmq(amqp_url):
24+
url = f"amqp://guest:guest@{amqp_url.host}:5672/"
25+
conn = RobustConnectionRRHost([url])
26+
assert conn.urls[0].scheme == "amqp"
27+
await conn.connect(timeout=5)
28+
assert not conn.is_closed
29+
await conn.close()
30+
31+
@pytest.mark.asyncio
32+
@pytest.mark.skip(reason="AMQPS non configurato nel server di test")
33+
async def test_amqps_scheme_with_rabbitmq(amqp_url):
34+
url = f"amqps://guest:guest@{amqp_url.host}:5671/"
35+
conn = RobustConnectionRRHost([url])
36+
await conn.connect(timeout=5)
37+
assert not conn.is_closed
38+
await conn.close()
39+
40+
@pytest.mark.asyncio
41+
async def test_no_scheme_defaults_to_amqp(amqp_url):
42+
raw_url = f"guest:guest@{amqp_url.host}:5672"
43+
url = f"amqp://{raw_url}"
44+
parsed = URL(url)
45+
if parsed.port is None:
46+
parsed = parsed.with_port(5672)
47+
conn = RobustConnectionRRHost([str(parsed)])
48+
assert conn.urls[0].scheme == "amqp"
49+
await conn.connect(timeout=5)
50+
assert not conn.is_closed
51+
await conn.close()
52+
53+
@pytest.mark.asyncio
54+
async def test_host_and_port_only(amqp_url):
55+
raw_url = f"{amqp_url.host}:5672"
56+
url = f"amqp://{raw_url}"
57+
parsed = URL(url)
58+
if parsed.port is None:
59+
parsed = parsed.with_port(5672)
60+
conn = RobustConnectionRRHost([str(parsed)])
61+
assert conn.urls[0].host == amqp_url.host
62+
await conn.connect(timeout=5)
63+
assert not conn.is_closed
64+
await conn.close()

0 commit comments

Comments
 (0)