diff --git a/README.rst b/README.rst index eb61c0d579..b90ec6c89c 100644 --- a/README.rst +++ b/README.rst @@ -70,6 +70,7 @@ and the `Wikipedia article about AMQP`_. .. _`Pyro`: https://pyro4.readthedocs.io/ .. _`SoftLayer MQ`: https://sldn.softlayer.com/reference/messagequeueapi .. _`MongoDB`: https://www.mongodb.com/ +.. _`AWS SNS`: https://aws.amazon.com/sns/ .. _transport-comparison: @@ -102,9 +103,9 @@ Transport Comparison .. [#f1] Declarations only kept in memory, so exchanges/queues must be declared by all clients that needs them. -.. [#f2] Fanout supported via storing routing tables in SimpleDB. - Disabled by default, but can be enabled by using the - ``supports_fanout`` transport option. +.. [#f2] Fanout is supported via `AWS SNS`_. A notification is sent to SNS, and a copy is set to all subscribed + `Amazon SQS`_ queues. Please consult the AWS SNS and SQS pricing pages to see how this will affect your usage + costs. Disabled by default, but can be enabled by using the ``supports_fanout`` transport option. .. [#f3] AMQP Message priority support depends on broker implementation. diff --git a/docs/reference/kombu.transport.SQS.rst b/docs/reference/kombu.transport.SQS.rst index e5136ebdc6..4d44cc03ad 100644 --- a/docs/reference/kombu.transport.SQS.rst +++ b/docs/reference/kombu.transport.SQS.rst @@ -23,6 +23,7 @@ :members: :undoc-members: + Back-off policy ------------------------ Back-off policy is using SQS visibility timeout mechanism altering the time difference between task retries. @@ -70,4 +71,22 @@ Message Attributes SQS supports sending message attributes along with the message body. To use this feature, you can pass a 'message_attributes' as keyword argument -to `basic_publish` method. \ No newline at end of file +to `basic_publish` method. + + +Amazon SQS Transport - ``kombu.transport.SQS.exceptions`` +================================================ + +.. 
automodule:: kombu.transport.SQS.exceptions + :members: + :show-inheritance: + :undoc-members: + + +Amazon SQS Transport - ``kombu.transport.SQS.SNS`` +================================================ + +.. automodule:: kombu.transport.SQS.SNS + :members: + :show-inheritance: + :undoc-members: diff --git a/kombu/transport/SQS/SNS.py b/kombu/transport/SQS/SNS.py new file mode 100644 index 0000000000..ceb02190e5 --- /dev/null +++ b/kombu/transport/SQS/SNS.py @@ -0,0 +1,568 @@ +"""Amazon SNS fanout support for the AWS SQS transport module for Kombu. + +This module provides a `SNS` class that can be used to manage SNS topics and subscriptions. +It's primarily used to provide fanout support via AWS Simple Notification Service (SNS) +topics and subscriptions. The module also provides methods for handling the lifecycle +of these topics. +""" +from __future__ import annotations + +import json +import threading +from datetime import datetime +from typing import TYPE_CHECKING + +from botocore.exceptions import ClientError + +from kombu.exceptions import KombuError +from kombu.log import get_logger + +from .exceptions import UndefinedExchangeException + +# pragma: no branch +if TYPE_CHECKING: + from . import Channel + +logger = get_logger(__name__) + + +class SNS: + """A class to manage AWS Simple Notification Service (SNS) for fanout exchanges. + + This class maintains caches of SNS subscriptions, clients, topic ARNs etc to + enable efficient management of SNS topics and subscriptions. 
+ """ + def __init__(self, channel: Channel): + self.channel = channel + self._client = None + self.subscriptions = _SnsSubscription(self) + self._predefined_clients = {} # A client for each predefined queue + self._topic_arn_cache: dict[str, str] = {} # SNS topic name => Topic ARN + self._exchange_topic_cache: dict[str, str] = {} # Exchange name => SNS topic ARN + self.sts_expiration: datetime | None = None # Cached STS expiration time + self._lock = threading.Lock() + + def initialise_exchange(self, exchange_name: str) -> None: + """Initialise SNS topic for a fanout exchange. + + This method will create the SNS topic if it doesn't exist, and check for any SNS topic subscriptions + that no longer exist. + + :param exchange_name: The name of the exchange. + :returns: None + """ + # Clear any old subscriptions + self.subscriptions.cleanup(exchange_name) + + # If topic has already been initialised, then do nothing + if self._topic_arn_cache.get(exchange_name): + return None + + # If predefined_exchanges are set, then do not try to create an SNS topic + if self.channel.predefined_exchanges: + logger.debug( + "'predefined_exchanges' has been specified, so SNS topics will" + " not be created." + ) + return + + with self._lock: + # Create the topic and cache the ARN + self._topic_arn_cache[exchange_name] = self._create_sns_topic(exchange_name) + return None + + def publish( + self, + exchange_name: str, + message: str, + message_attributes: dict = None, + request_params: dict = None, + ) -> None: + """Send a notification to AWS Simple Notification Service (SNS). + + :param exchange_name: The name of the exchange. + :param message: The message to be sent as a JSON string + :param message_attributes: Attributes for the message. + :param request_params: Additional parameters for SNS notification. 
+ :return: None + """ + # Get topic ARN for the given exchange + topic_arn = self._get_topic_arn(exchange_name) + + # Build request args for boto + request_args: dict[str, str | dict] = { + "TopicArn": topic_arn, + "Message": message, + } + request_args.update(request_params or {}) + + # Serialise message attributes into SNS format + if serialised_attrs := self.serialise_message_attributes(message_attributes): + request_args["MessageAttributes"] = serialised_attrs + + # Send event to topic + response = self.get_client(exchange_name).publish(**request_args) + if (status_code := response["ResponseMetadata"]["HTTPStatusCode"]) != 200: + raise UndefinedExchangeException( + f"Unable to send message to topic '{topic_arn}': status code was {status_code}" + ) + + def _get_topic_arn(self, exchange_name: str) -> str: + """Get the SNS topic ARN. + + If the topic ARN is not in the cache, then create it + :param exchange_name: The exchange to create the SNS topic for + :return: The SNS topic ARN + """ + # If topic ARN is in the cache, then return it + if topic_arn := self._topic_arn_cache.get(exchange_name): + return topic_arn + + # If predefined-exchanges are used, then do not create a new topic and raise an exception + if self.channel.predefined_exchanges: + with self._lock: + # Try and get the topic ARN from the predefined_exchanges and add it to the cache + topic_arn = self._topic_arn_cache[exchange_name] = ( + self.channel.predefined_exchanges.get(exchange_name, {}).get("arn") + ) + if topic_arn: + return topic_arn + + # If pre-defined exchanges do not have the exchange, then raise an exception + raise UndefinedExchangeException( + f"Exchange with name '{exchange_name}' must be defined in 'predefined_exchanges'." 
+ ) + + # If predefined_caches are not used, then create a new SNS topic/retrieve the ARN from AWS SNS and cache it + with self._lock: + arn = self._topic_arn_cache[exchange_name] = self._create_sns_topic( + exchange_name + ) + return arn + + def _create_sns_topic(self, exchange_name: str) -> str: + """Creates an AWS SNS topic. + + If the topic already exists, AWS will return it's ARN without creating a new one. + + :param exchange_name: The exchange to create the SNS topic for + :return: Topic ARN + """ + # Create the SNS topic/Retrieve the SNS topic ARN + topic_name = self.channel.canonical_queue_name(exchange_name) + + logger.debug(f"Creating SNS topic '{topic_name}'") + + # Call SNS API to create the topic + response = self.get_client().create_topic( + Name=topic_name, + Attributes={ + "FifoTopic": str(topic_name.endswith(".fifo")), + }, + Tags=[ + {"Key": "ManagedBy", "Value": "Celery/Kombu"}, + { + "Key": "Description", + "Value": "This SNS topic is used by Kombu to enable Fanout support for AWS SQS.", + }, + ], + ) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise UndefinedExchangeException("Unable to create SNS topic") + + # Extract the ARN from the response + arn = response["TopicArn"] + logger.info(f"Created SNS topic '{topic_name}' with ARN '{arn}'") + + return arn + + @staticmethod + def serialise_message_attributes(message_attributes: dict) -> dict: + """Serialises SQS message attributes into SNS format. + + :param message_attributes: A dictionary of message attributes + :returns: A dictionary of serialised message attributes in SNS format. + """ + if not message_attributes: + return {} + + attrs = {} + for key, value in message_attributes.items(): + attrs[key] = { + "DataType": "String", + "StringValue": str(value), + } + + return attrs + + def get_client(self, exchange_name: str | None = None): + """Get or create a Boto SNS client. + + If an SNS client has already been initialised for this Channel instance, return it. 
If not, create a new SNS + client, add it to this Channel instance and return it. + + If the exchange is defined in the predefined_exchanges, then return the client for the exchange and handle + any STS token renewal. + + :param exchange_name: The name of the exchange + :returns: A Boto SNS client. + """ + # Attempt to get predefined client for exchange if it has been provided + if exchange_name is not None and self.channel.predefined_exchanges: + # Raise if queue is not defined + if not (e := self.channel.predefined_exchanges.get(exchange_name)): + raise UndefinedExchangeException( + f"Exchange with name '{exchange_name}' must be defined in 'predefined_exchanges'." + ) + + # Handle authenticating boto client with tokens + if self.channel.transport_options.get("sts_role_arn"): + return self._handle_sts_session(exchange_name, e) + + # If the queue has already been defined, then return the client for the queue + if c := self._predefined_clients.get(exchange_name): + return c + + # Create client, add it to the queue map and return + c = self._predefined_clients[exchange_name] = self._create_boto_client( + region=e.get("region", self.channel.region), + access_key_id=e.get("access_key_id", self.channel.conninfo.userid), + secret_access_key=e.get( + "secret_access_key", self.channel.conninfo.password + ), + ) + return c + + # If SQS client has been initialised, return it + if self._client is not None: + return self._client + + # Initialise a new SQS client and return it + c = self._client = self._create_boto_client( + region=self.channel.region, + access_key_id=self.channel.conninfo.userid, + secret_access_key=self.channel.conninfo.password, + ) + return c + + def _handle_sts_session(self, exchange_name: str, e: dict): + """Checks if the STS token needs renewing for SNS. 
+ + :param exchange_name: The exchange name + :param e: The exchange object + :returns: The SNS client with a refreshed STS token + """ + # Check if a token refresh is needed + if self.channel.is_sts_token_refresh_required( + name=exchange_name, + client_map=self._predefined_clients, + expire_time=self.sts_expiration, + ): + return self._create_boto_client_with_sts_session( + exchange_name, region=e.get("region", self.channel.region) + ) + + # If token refresh is not required, return existing client + return self._predefined_clients[exchange_name] + + def _create_boto_client_with_sts_session(self, exchange_name: str, region: str): + """Creates a new SNS client with a refreshed STS token. + + :param exchange_name: The exchange name + :param region: The AWS region to use. + :returns: The SNS client with a refreshed STS token. + """ + # Handle STS token refresh + sts_creds = self.channel.get_sts_credentials() + self.sts_expiration = sts_creds["Expiration"] + + # Get new client and return it + c = self._predefined_clients[exchange_name] = self._create_boto_client( + region=region, + access_key_id=sts_creds["AccessKeyId"], + secret_access_key=sts_creds["SecretAccessKey"], + session_token=sts_creds["SessionToken"], + ) + return c + + def _create_boto_client( + self, region, access_key_id, secret_access_key, session_token=None + ): + """Create a new SNS client. + + :param region: The AWS region to use. + :param access_key_id: The AWS access key ID for authenticating with boto. + :param secret_access_key: The AWS secret access key for authenticating with boto. + :param session_token: The AWS session token for authenticating with boto, if required. + :returns: A Boto SNS client. 
+ """ + return self.channel._new_boto_client( + service="sns", + region=region, + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + ) + + # --------------------------------- + # SNS topic subscription management + # --------------------------------- + + +class _SnsSubscription: + _queue_arn_cache: dict[str, str] = {} # SQS queue URL => Queue ARN + _subscription_arn_cache: dict[str, str] = {} # Queue => Subscription ARN + + _lock = threading.Lock() + + def __init__(self, sns_fanout: SNS): + self.sns = sns_fanout + + def subscribe_queue(self, queue_name: str, exchange_name: str) -> str: + """Subscribes a queue to an AWS SNS topic. + + :param queue_name: The queue to subscribe + :param exchange_name: The exchange to subscribe to the queue, if not provided + :raises: UndefinedExchangeException if exchange is not defined. + :return: The subscription ARN + """ + # Get exchange from Queue and raise if not defined + cache_key = f"{exchange_name}:{queue_name}" + + # If the subscription ARN is already cached, return it + if subscription_arn := self._subscription_arn_cache.get(cache_key): + return subscription_arn + + # Get ARNs for queue and topic + queue_arn = self._get_queue_arn(queue_name) + topic_arn = self.sns._get_topic_arn(exchange_name) + + # Subscribe the SQS queue to the SNS topic + subscription_arn = self._subscribe_queue_to_sns_topic( + queue_arn=queue_arn, topic_arn=topic_arn + ) + + # Setup permissions for the queue to receive messages from the topic + self._set_permission_on_sqs_queue( + topic_arn=topic_arn, queue_arn=queue_arn, queue_name=queue_name + ) + + # Update subscription ARN cache + with self._lock: + self._subscription_arn_cache[cache_key] = subscription_arn + + return subscription_arn + + def unsubscribe_queue(self, queue_name: str, exchange_name: str) -> None: + """Unsubscribes a queue from an AWS SNS topic. 
+ + :param queue_name: The queue to unsubscribe + :param exchange_name: The exchange to unsubscribe from the queue, if not provided + :return: None + """ + cache_key = f"{exchange_name}:{queue_name}" + # Get subscription ARN from cache if it exists, and return if it exists + if not (subscription_arn := self._subscription_arn_cache.get(cache_key)): + return + + # Unsubscribe the SQS queue from the SNS topic + self._unsubscribe_sns_subscription(subscription_arn) + logger.info( + f"Unsubscribed subscription '{subscription_arn}' for SQS queue '{queue_name}'" + ) + + def cleanup(self, exchange_name: str) -> None: + """Removes any stale SNS topic subscriptions. + + This method will check that any SQS subscriptions on the SNS topic are associated with SQS queues. If not, + it will remove the stale subscription. + + :param exchange_name: The exchange to check for stale subscriptions + :return: None + """ + # If predefined_exchanges are set, then do not try to remove subscriptions + if self.sns.channel.predefined_exchanges: + logger.debug( + "'predefined_exchanges' has been specified, so stale SNS subscription cleanup will be skipped." 
+ ) + return + + logger.debug( + f"Checking for stale SNS subscriptions for exchange '{exchange_name}'" + ) + + # Get subscriptions to check + topic_arn = self.sns._get_topic_arn(exchange_name) + + # Iterate through the subscriptions and remove any that are not associated with SQS queues + for subscription_arn in self._get_invalid_sns_subscriptions(topic_arn): + # Unsubscribe the SQS queue from the SNS topic + try: + self._unsubscribe_sns_subscription(subscription_arn) + logger.info( + f"Removed stale subscription '{subscription_arn}' for SNS topic '{topic_arn}'" + ) + + # Report any failures to remove the subscription and continue to the next as this is not a critical error + except Exception as e: + logger.warning( + f"Failed to remove stale subscription '{subscription_arn}' for SNS topic '{topic_arn}': {e}" + ) + + def _subscribe_queue_to_sns_topic(self, queue_arn: str, topic_arn: str) -> str: + """Subscribes a queue to an AWS SNS topic. + + :param queue_arn: The ARN of the queue to subscribe + :param topic_arn: The ARN of the SNS topic to subscribe to + :raises: UndefinedExchangeException if exchange is not defined. 
+ :return: The subscription ARN + """ + logger.debug(f"Subscribing queue '{queue_arn}' to SNS topic '{topic_arn}'") + + # Request SNS client to subscribe the queue to the topic + response = self.sns.get_client().subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint=queue_arn, + Attributes={"RawMessageDelivery": "true"}, + ReturnSubscriptionArn=True, + ) + if (status_code := response["ResponseMetadata"]["HTTPStatusCode"]) != 200: + raise Exception(f"Unable to subscribe queue: status code was {status_code}") + + # Extract the subscription ARN from the response and log + subscription_arn = response["SubscriptionArn"] + logger.info( + f"Create subscription '{subscription_arn}' for SQS queue '{queue_arn}' to" + f" SNS topic '{topic_arn}'" + ) + + return subscription_arn + + def _set_permission_on_sqs_queue( + self, topic_arn: str, queue_name: str, queue_arn: str + ): + """Sets the permissions on an AWS SQS queue to enable the SNS topic to publish to the queue. + + :param topic_arn: The ARN of the SNS topic + :param queue_name: The queue name to set permissions for + :param queue_arn: The ARN of the SQS queue + :return: None + """ + self.sns.channel.sqs().set_queue_attributes( + QueueUrl=self.sns.channel._resolve_queue_url(queue_name), + Attributes={ + "Policy": json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "KombuManaged", + "Effect": "Allow", + "Principal": {"Service": "sns.amazonaws.com"}, + "Action": "SQS:SendMessage", + "Resource": queue_arn, + "Condition": {"ArnLike": {"aws:SourceArn": topic_arn}}, + } + ], + } + ) + }, + ) + logger.debug(f"Set permissions on SNS topic '{topic_arn}'") + + def _unsubscribe_sns_subscription(self, subscription_arn: str) -> None: + """Unsubscribes a subscription from an AWS SNS topic. 
+ + :param subscription_arn: The ARN of the subscription to unsubscribe + :return: None + """ + response = self.sns.get_client().unsubscribe(SubscriptionArn=subscription_arn) + if (status_code := response["ResponseMetadata"]["HTTPStatusCode"]) != 200: + logger.error( + f"Unable to remove subscription '{subscription_arn}': status code was {status_code}" + ) + + def _get_invalid_sns_subscriptions(self, sns_topic_arn: str) -> list[str]: + """Get a list of all invalid SQS subscriptions associated with a given SNS topic. + + :param sns_topic_arn: The SNS topic ARN to check + :return: A list of SNS subscription ARNs that are invalid + """ + paginator = self.sns.get_client().get_paginator("list_subscriptions_by_topic") + + # Iterate through the paginated subscriptions and build a list of subscriptions to check + invalid_subscription_arns = [] + for response in paginator.paginate(TopicArn=sns_topic_arn): + invalid_subscription_arns.extend( + self._filter_sns_subscription_response(response.get("Subscriptions")) + ) + + return invalid_subscription_arns + + def _filter_sns_subscription_response(self, subscriptions: list[dict]) -> list[str]: + """Returns a list of SNS subscription ARNs that are not associated with a SQS queue. + + :param subscriptions: A list of subscriptions for an SNS topic + :return: A list of subscription ARNs that are dead + """ + subscription_arns = [] + + # If the subscriptions list is empty or None, return an empty list + if not subscriptions: + return subscription_arns + + # Iterate through the subscriptions and check if the queue is valid + for subscription in subscriptions: + # Skip subscription if it is not for SQS + if not subscription.get("Protocol", "").lower() == "sqs": + continue + + # Extract the SQS queue ARN from the subscription endpoint + queue_name = subscription["Endpoint"].split(":")[-1] + + # Check if the queue has been removed by calling the get queue URL method. 
+ # Note: listing the queues sometimes results in a valid queue not being + # returned (due to eventual consistency in SQS), so calling this method + # helps to mitigate this. + try: + self.sns.channel.sqs().get_queue_url(QueueName=queue_name) + except ClientError as e: + queue_missing_errs = ["QueueDoesNotExist", "NonExistentQueue"] + # If one of the errors above has been raised, then the queue has been + # removed and the subscription should be removed too. + if any(err in str(e) for err in queue_missing_errs): + subscription_arns.append(subscription["SubscriptionArn"]) + else: + raise + + return subscription_arns + + def _get_queue_arn(self, queue_name: str) -> str: + """Returns the ARN of the SQS queue associated with the given queue. + + This method will return the ARN from the cache if it exists, otherwise it will fetch it from SQS. + + :param queue_name: The queue to get the ARN for + """ + # Check if the queue ARN is already cached, and return if it exists + if arn := self._queue_arn_cache.get(queue_name): + return arn + + queue_url = self.sns.channel._resolve_queue_url(queue_name) + + # Get the ARN for the SQS queue + response = self.sns.channel.sqs().get_queue_attributes( + QueueUrl=queue_url, AttributeNames=["QueueArn"] + ) + if (status_code := response["ResponseMetadata"]["HTTPStatusCode"]) != 200: + raise KombuError( + f"Unable to get ARN for SQS queue '{queue_name}': " + f"status code was '{status_code}'" + ) + + # Update queue ARN cache + with self._lock: + arn = self._queue_arn_cache[queue_name] = response["Attributes"]["QueueArn"] + + return arn diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS/__init__.py similarity index 71% rename from kombu/transport/SQS.py rename to kombu/transport/SQS/__init__.py index cfe2fe0148..d5c89f2935 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS/__init__.py @@ -93,6 +93,32 @@ to access with. sts_token_timeout is the token timeout, defaults (and minimum) to 900 seconds. 
After the mentioned period, a new token will be created. +-------------------- +Predefined Exchanges +-------------------- +When using a fanout exchange with this transport, messages are sent to an AWS SNS, which then forwards the messages +to all subscribed queues. + +The default behavior of this transport is to create the SNS topic when the exchange is first declared. +However, it is also possible to use a predefined SNS topic instead of letting the transport create it. + +.. code-block:: python + + transport_options = { + 'predefined_exchanges': { + 'exchange-1': { + 'arn': 'arn:aws:sns:us-east-1:xxx:exchange-1', + 'access_key_id': 'a', + 'secret_access_key': 'b', + }, + 'exchange-2.fifo': { + 'arn': 'arn:aws:sns:us-east-1:xxx:exchange-2', + 'access_key_id': 'c', + 'secret_access_key': 'd', + }, + } + } + .. versionadded:: 5.6.0 sts_token_buffer_time (seconds) is the time by which you want to refresh your token earlier than its actual expiration time, defaults to 0 (no time buffer will be added), @@ -135,7 +161,6 @@ * Supports TTL: No """ - from __future__ import annotations import base64 @@ -147,23 +172,29 @@ from datetime import datetime, timedelta, timezone from json import JSONDecodeError from queue import Empty -from typing import Any +from typing import Any, Literal -from botocore.client import Config +from botocore.client import BaseClient, Config from botocore.exceptions import ClientError from vine import ensure_promise, promise, transform from kombu.asynchronous import get_event_loop -from kombu.asynchronous.aws.ext import boto3, exceptions +from kombu.asynchronous.aws.ext import boto3 +from kombu.asynchronous.aws.ext import exceptions as aws_exceptions from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection from kombu.asynchronous.aws.sqs.message import AsyncMessage from kombu.log import get_logger +from kombu.transport import virtual +from kombu.transport.SQS.exceptions import (AccessDeniedQueueException, + DoesNotExistQueueException, + 
InvalidQueueException, + UndefinedQueueException) from kombu.utils import scheduling from kombu.utils.encoding import bytes_to_str, safe_str from kombu.utils.json import dumps, loads from kombu.utils.objects import cached_property -from . import virtual +from .SNS import SNS logger = get_logger(__name__) @@ -177,35 +208,17 @@ #: SQS bulk get supports a maximum of 10 messages at a time. SQS_MAX_MESSAGES = 10 +_SUPPORTED_BOTO_SERVICES = Literal["sqs", "sns"] + def maybe_int(x): """Try to convert x' to int, or return x' if that fails.""" try: return int(x) - except ValueError: + except (TypeError, ValueError): return x -class UndefinedQueueException(Exception): - """Predefined queues are being used and an undefined queue was used.""" - - -class InvalidQueueException(Exception): - """Predefined queues are being used and configuration is not valid.""" - - -class AccessDeniedQueueException(Exception): - """Raised when access to the AWS queue is denied. - - This may occur if the permissions are not correctly set or the - credentials are invalid. 
- """ - - -class DoesNotExistQueueException(Exception): - """The specified queue doesn't exist.""" - - class QoS(virtual.QoS): """Quality of Service guarantees implementation for SQS.""" @@ -253,7 +266,7 @@ def extract_task_name_and_number_of_retries(self, delivery_tag): task_name = message_headers['task'] number_of_retries = int( message.properties['delivery_info']['sqs_message'] - ['Attributes']['ApproximateReceiveCount']) + ['Attributes']['ApproximateReceiveCount']) return task_name, number_of_retries @@ -267,12 +280,15 @@ class Channel(virtual.Channel): _asynsqs = None _predefined_queue_async_clients = {} # A client for each predefined queue _sqs = None + _fanout = None _predefined_queue_clients = {} # A client for each predefined queue _queue_cache = {} # SQS queue name => SQS queue URL _noack_queues = set() + QoS = QoS # https://stackoverflow.com/questions/475074/regex-to-parse-or-validate-base64-data - B64_REGEX = re.compile(rb'^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$') + B64_REGEX = re.compile( + rb'^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$') def __init__(self, *args, **kwargs): if boto3 is None: @@ -319,14 +335,15 @@ def _update_queue_cache(self, queue_name_prefix): queue_name = url.split('/')[-1] self._queue_cache[queue_name] = url - def basic_consume(self, queue, no_ack, *args, **kwargs): + def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): + # If using a Fanout exchange, then subscribe to the queue to SNS + self._subscribe_queue_to_fanout_exchange_if_required(queue) + if no_ack: self._noack_queues.add(queue) if self.hub: self._loop1(queue) - return super().basic_consume( - queue, no_ack, *args, **kwargs - ) + return super().basic_consume(queue, no_ack, callback, consumer_tag, **kwargs) def basic_cancel(self, consumer_tag): if consumer_tag in self._consumers: @@ -334,6 +351,11 @@ def basic_cancel(self, consumer_tag): self._noack_queues.discard(queue) return 
super().basic_cancel(consumer_tag) + def _queue_bind(self, exchange, routing_key, pattern, queue): + # If the exchange is a fanout exchange, initialise the SNS topic + if self._exchange_is_fanout(exchange): + self.fanout.initialise_exchange(exchange) + def drain_events(self, timeout=None, callback=None, **kwargs): """Return a single payload message from one of our queues. @@ -391,9 +413,9 @@ def _resolve_queue_url(self, queue): except KeyError: if self.predefined_queues: raise UndefinedQueueException(( - "Queue with name '{}' must be " - "defined in 'predefined_queues'." - ).format(sqs_qname)) + "Queue with name '{}' must be " + "defined in 'predefined_queues'." + ).format(sqs_qname)) raise DoesNotExistQueueException( f"Queue with name '{sqs_qname}' doesn't exist in SQS" @@ -417,15 +439,17 @@ def _new_queue(self, queue, **kwargs): attributes['FifoQueue'] = 'true' resp = self._create_queue(sqs_qname, attributes) - self._queue_cache[sqs_qname] = resp['QueueUrl'] - return resp['QueueUrl'] + queue_url = self._queue_cache[sqs_qname] = resp["QueueUrl"] + return queue_url def _create_queue(self, queue_name, attributes): """Create an SQS queue with a given name and nominal attributes.""" # Allow specifying additional boto create_queue Attributes # via transport options if self.predefined_queues: - return None + raise UndefinedQueueException( + f"Queue with name '{queue_name}' must be defined in 'predefined_queues'." + ) attributes.update( self.transport_options.get('sqs-creation-attributes') or {}, @@ -443,8 +467,13 @@ def _create_queue(self, queue_name, attributes): return self.sqs(queue=queue_name).create_queue(**create_params) - def _delete(self, queue, *args, **kwargs): - """Delete queue by name.""" + def _delete(self, queue, exchange: str | None = None, *args, **kwargs): + """Delete queue by name. 
+ + :param queue: The queue name + :param exchange: The exchange name, if any + :return: None + """ if self.predefined_queues: return @@ -452,6 +481,12 @@ def _delete(self, queue, *args, **kwargs): self.sqs().delete_queue( QueueUrl=q_url, ) + # If the exchange is a fanout exchange, unsubscribe the queue to the SNS topic + if exchange and self._exchange_is_fanout(exchange): + self.fanout.subscriptions.unsubscribe_queue( + queue_name=queue, exchange_name=exchange + ) + self._queue_cache.pop(queue, None) def _put(self, queue, message, **kwargs): @@ -543,21 +578,20 @@ def _receive_message( """ q_url: str = self._new_queue(queue) client = self.sqs(queue=queue) - - message_system_attribute_names = self.get_message_attributes.get( - 'MessageSystemAttributeNames') or [] - - message_attribute_names = self.get_message_attributes.get( - 'MessageAttributeNames') or [] + msg_attrs = self.get_message_attributes params: dict[str, Any] = { 'QueueUrl': q_url, 'MaxNumberOfMessages': max_number_of_messages, 'WaitTimeSeconds': wait_time_seconds or self.wait_time_seconds, - 'MessageAttributeNames': message_attribute_names, - 'MessageSystemAttributeNames': message_system_attribute_names } + if msg_sys_attrs := msg_attrs.get('MessageSystemAttributeNames'): + params['MessageSystemAttributeNames'] = msg_sys_attrs + + if msg_attrs := msg_attrs.get('MessageAttributeNames'): + params['MessageAttributeNames'] = msg_attrs + return client.receive_message(**params) def _get_bulk(self, queue, @@ -598,12 +632,13 @@ def _get_bulk(self, queue, max_number_of_messages=max_count ) - if resp.get('Messages'): - for m in resp['Messages']: - m['Body'] = AsyncMessage(body=m['Body']).decode() - for msg in self._messages_to_python(resp['Messages'], queue): + if messages := resp.get("Messages"): + for m in messages: + m["Body"] = AsyncMessage(body=m["Body"]).decode() + for msg in self._messages_to_python(messages, queue): self.connection._deliver(msg, queue) return + raise Empty() def _get(self, queue): @@ 
-614,10 +649,10 @@ def _get(self, queue): max_number_of_messages=1 ) - if resp.get('Messages'): - body = AsyncMessage(body=resp['Messages'][0]['Body']).decode() - resp['Messages'][0]['Body'] = body - return self._messages_to_python(resp['Messages'], queue)[0] + if messages := resp.get("Messages"): + body = AsyncMessage(body=messages[0]["Body"]).decode() + messages[0]["Body"] = body + return self._messages_to_python(messages, queue)[0] raise Empty() def _loop1(self, queue, _=None): @@ -705,7 +740,7 @@ def basic_ack(self, delivery_tag, multiple=False): if exception.response['Error']['Code'] == 'AccessDenied': raise AccessDeniedQueueException( exception.response["Error"]["Message"] - ) + ) super().basic_reject(delivery_tag) else: super().basic_ack(delivery_tag) @@ -716,7 +751,8 @@ def _size(self, queue): c = self.sqs(queue=self.canonical_queue_name(queue)) resp = c.get_queue_attributes( QueueUrl=q_url, - AttributeNames=['ApproximateNumberOfMessages']) + AttributeNames=['ApproximateNumberOfMessages'] + ) return int(resp['Attributes']['ApproximateNumberOfMessages']) def _purge(self, queue): @@ -741,9 +777,34 @@ def close(self): # if "can't set attribute" not in str(exc): # raise - def new_sqs_client(self, region, access_key_id, - secret_access_key, session_token=None): - session = boto3.session.Session( + def new_sqs_client( + self, region, access_key_id, secret_access_key, session_token=None + ): + """Create a new SQS client. + + :param region: The AWS region to use. + :param access_key_id: The AWS access key ID for authenticating with boto. + :param secret_access_key: The AWS secret access key for authenticating with boto. + :param session_token: The AWS session token for authenticating with boto, if required. + :returns: A Boto SQS client. 
+        """
+        return self._new_boto_client(
+            service="sqs",
+            region=region,
+            access_key_id=access_key_id,
+            secret_access_key=secret_access_key,
+            session_token=session_token,
+        )
+
+    def _new_boto_client(
+        self,
+        service: _SUPPORTED_BOTO_SERVICES,
+        region,
+        access_key_id,
+        secret_access_key,
+        session_token=None,
+    ):
+        session = boto3.Session(
             region_name=region,
             aws_access_key_id=access_key_id,
             aws_secret_access_key=secret_access_key,
@@ -757,35 +818,40 @@ def new_sqs_client(self, region, access_key_id,
             client_kwargs['endpoint_url'] = self.endpoint_url
         client_config = self.transport_options.get('client-config') or {}
         config = Config(**client_config)
-        return session.client('sqs', config=config, **client_kwargs)
+        return session.client(service, config=config, **client_kwargs)
 
     def sqs(self, queue=None):
+        # If a queue has been provided, check if the queue has been defined already. Reuse its client if possible.
         if queue is not None and self.predefined_queues:
-
+            # Raise if queue is not defined
             if queue not in self.predefined_queues:
                 raise UndefinedQueueException(
-                    f"Queue with name '{queue}' must be defined"
-                    " in 'predefined_queues'.")
+                    f"Queue with name '{queue}' must be defined in 'predefined_queues'."
+ ) + q = self.predefined_queues[queue] - if self.transport_options.get('sts_role_arn'): + + # Handle authenticating boto client with tokens + if self.transport_options.get("sts_role_arn"): return self._handle_sts_session(queue, q) - if not self.transport_options.get('sts_role_arn'): - if queue in self._predefined_queue_clients: - return self._predefined_queue_clients[queue] - else: - c = self._predefined_queue_clients[queue] = \ - self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=q.get( - 'access_key_id', self.conninfo.userid), - secret_access_key=q.get( - 'secret_access_key', self.conninfo.password) - ) - return c + # If the queue has already been defined, then return the client for the queue + if c := self._predefined_queue_clients.get(queue): + return c + + # Create client, add it to the queue map and return + c = self._predefined_queue_clients[queue] = self.new_sqs_client( + region=q.get("region", self.region), + access_key_id=q.get("access_key_id", self.conninfo.userid), + secret_access_key=q.get("secret_access_key", self.conninfo.password), + ) + return c + + # If SQS client has been initialised, return it if self._sqs is not None: return self._sqs + # Initialise a new SQS client and return it c = self._sqs = self.new_sqs_client( region=self.region, access_key_id=self.conninfo.userid, @@ -793,19 +859,87 @@ def sqs(self, queue=None): ) return c - def _handle_sts_session(self, queue, q): - region = q.get('region', self.region) - if not hasattr(self, 'sts_expiration'): # STS token - token init + @property + def fanout(self) -> SNS: + """Provides SNS fanout functionality. + + This method returns the fanout instance. If an instance of the fanout class + has not been initialised, then initialise it. + + :returns: An instance of SNS fanout class. 
+ """ + # If an SNS class has not been initialised, then initialise it + if not self._fanout: + self._fanout = SNS(self) + return self._fanout + + def remove_stale_sns_subscriptions(self, exchange_name: str) -> None: + """Removes any stale SNS topic subscriptions. + + This method will check that any SQS subscriptions on the SNS topic are associated with SQS queues. If not, + it will remove the stale subscription. This method will only work if the 'supports_fanout' property is True. + + :param exchange_name: The exchange to check for stale subscriptions + :return: None + """ + if self._exchange_is_fanout(exchange_name): + return self.fanout.subscriptions.cleanup(exchange_name) + return None + + def _handle_sts_session(self, queue: str, q): + """Checks if the STS token needs renewing for SQS. + + :param queue: The queue name + :param q: The queue object + :returns: The SQS client with a refreshed STS token + """ + region = q.get("region", self.region) + + # Check if a token refresh is needed + if self.is_sts_token_refresh_required( + name=queue, + client_map=self._predefined_queue_clients, + expire_time=getattr(self, "sts_expiration", None), + ): return self._new_predefined_queue_client_with_sts_session(queue, region) + + # If token refresh is not required, return existing client + return self._predefined_queue_clients[queue] + + @staticmethod + def is_sts_token_refresh_required( + name: Any, + client_map: dict[str, BaseClient], + expire_time: datetime | None = None, + ) -> bool: + """Checks if the STS token needs renewing. + + This method will check different STS expiry times depending on the service the token was used for. + + :param name: Either the queue name or exchange name + :param client_map: Map of client names to boto3 clients. Either the queue or exchange map + :param expire_time: The datetime when the token expires. + :returns: True if the token needs renewing, False otherwise. 
+        """
+        # Get the expiry time of the STS token depending on the service
+        if not expire_time:  # STS token - token init
+            return True
         # STS token - refresh if expired
-        elif self.sts_expiration.replace(tzinfo=None) < datetime.now(timezone.utc).replace(tzinfo=None):
-            return self._new_predefined_queue_client_with_sts_session(queue, region)
-        else:  # STS token - ruse existing
-            if queue not in self._predefined_queue_clients:
-                return self._new_predefined_queue_client_with_sts_session(queue, region)
-            return self._predefined_queue_clients[queue]
+        elif (
+            expire_time.replace(tzinfo=None) <
+            datetime.now(timezone.utc).replace(tzinfo=None)
+        ):
+            return True
+        # STS token - refresh if exchange or queue is not in client map
+        elif not client_map.get(name):
+            return True
+        # STS token - reuse existing
+        else:
+            return False
 
-    def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
+    def generate_sts_session_token_with_buffer(
+        self, role_arn, token_expiry_seconds, token_buffer_seconds=0
+    ):
         """Generate STS session credentials with an optional expiration buffer.
 
         The buffer is only applied if it is less than `token_expiry_seconds` to prevent an expired token.
@@ -816,22 +950,29 @@ def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, return credentials def _new_predefined_queue_client_with_sts_session(self, queue, region): - sts_creds = self.generate_sts_session_token_with_buffer( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900), - self.transport_options.get('sts_token_buffer_time', 0), - ) - self.sts_expiration = sts_creds['Expiration'] + # Handle STS token refresh + sts_creds = self.get_sts_credentials() + self.sts_expiration = sts_creds["Expiration"] + + # Get new client and return it c = self._predefined_queue_clients[queue] = self.new_sqs_client( region=region, - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], + access_key_id=sts_creds["AccessKeyId"], + secret_access_key=sts_creds["SecretAccessKey"], + session_token=sts_creds["SessionToken"], ) return c - def generate_sts_session_token(self, role_arn, token_expiry_seconds): - sts_client = boto3.client('sts') + def get_sts_credentials(self): + return self.generate_sts_session_token_with_buffer( + self.transport_options.get("sts_role_arn"), + self.transport_options.get("sts_token_timeout", 900), + self.transport_options.get("sts_token_buffer_time", 0), + ) + + @staticmethod + def generate_sts_session_token(role_arn: str, token_expiry_seconds: int): + sts_client = boto3.client("sts") sts_policy = sts_client.assume_role( RoleArn=role_arn, RoleSessionName='Celery', @@ -841,26 +982,28 @@ def generate_sts_session_token(self, role_arn, token_expiry_seconds): def asynsqs(self, queue=None): message_system_attribute_names = self.get_message_attributes.get( - 'MessageSystemAttributeNames') + "MessageSystemAttributeNames" + ) message_attribute_names = self.get_message_attributes.get( - 'MessageAttributeNames') + "MessageAttributeNames" + ) if queue is not None and self.predefined_queues: - if queue in 
self._predefined_queue_async_clients and \ - not hasattr(self, 'sts_expiration'): + if queue in self._predefined_queue_async_clients and not hasattr( + self, "sts_expiration" + ): return self._predefined_queue_async_clients[queue] - if queue not in self.predefined_queues: - raise UndefinedQueueException(( - "Queue with name '{}' must be defined in " - "'predefined_queues'." - ).format(queue)) - q = self.predefined_queues[queue] - c = self._predefined_queue_async_clients[queue] = \ - AsyncSQSConnection( - sqs_connection=self.sqs(queue=queue), - region=q.get('region', self.region), - message_system_attribute_names=message_system_attribute_names, - message_attribute_names=message_attribute_names + + if not (q := self.predefined_queues.get(queue)): + raise UndefinedQueueException( + f"Queue with name '{queue}' must be defined in 'predefined_queues'." + ) + + c = self._predefined_queue_async_clients[queue] = AsyncSQSConnection( + sqs_connection=self.sqs(queue=queue), + region=q.get("region", self.region), + message_system_attribute_names=message_system_attribute_names, + message_attribute_names=message_attribute_names, ) return c @@ -891,7 +1034,12 @@ def visibility_timeout(self): @cached_property def predefined_queues(self): """Map of queue_name to predefined queue settings.""" - return self.transport_options.get('predefined_queues', {}) + return self.transport_options.get("predefined_queues", {}) + + @cached_property + def predefined_exchanges(self): + """Map of exchange_name to predefined SNS client.""" + return self.transport_options.get("predefined_exchanges", {}) @cached_property def queue_name_prefix(self): @@ -899,7 +1047,7 @@ def queue_name_prefix(self): @cached_property def supports_fanout(self): - return False + return self.transport_options.get("supports_fanout", False) @cached_property def region(self): @@ -926,17 +1074,14 @@ def endpoint_url(self): if self.conninfo.port is not None: port = f':{self.conninfo.port}' else: - port = '' - return 
'{}://{}{}'.format( - scheme, - self.conninfo.hostname, - port - ) + port = "" + return f"{scheme}://{self.conninfo.hostname}{port}" @cached_property def wait_time_seconds(self) -> int: - return self.transport_options.get('wait_time_seconds', - self.default_wait_time_seconds) + return self.transport_options.get( + "wait_time_seconds", self.default_wait_time_seconds + ) @cached_property def sqs_base64_encoding(self): @@ -970,7 +1115,8 @@ def get_message_attributes(self) -> dict[str, Any]: } if isinstance(fetch, list): - message_system_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in fetch] else ( + message_system_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in + fetch] else ( list(set(fetch + [APPROXIMATE_RECEIVE_COUNT])) ) @@ -979,22 +1125,105 @@ def get_message_attributes(self) -> dict[str, Any]: attrs = fetch.get('MessageAttributeNames', None) if isinstance(system, list): - message_system_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in system] else ( + message_system_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in + system] else ( list(set(system + [APPROXIMATE_RECEIVE_COUNT])) ) if isinstance(attrs, list) and attrs: - message_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in attrs] else ( + message_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in + attrs] else ( list(set(attrs)) ) return { - 'MessageAttributeNames': sorted(message_attrs) if message_attrs else [], - 'MessageSystemAttributeNames': ( - sorted(message_system_attrs) if message_system_attrs else [APPROXIMATE_RECEIVE_COUNT] - ) + "MessageAttributeNames": sorted(message_attrs) if message_attrs else [], + "MessageSystemAttributeNames": ( + sorted(message_system_attrs) + if message_system_attrs + else [APPROXIMATE_RECEIVE_COUNT] + ), } + def _put_fanout(self, exchange: str, message: dict, routing_key, **kwargs): + """Add a message to fanout queues by adding a notification to an SNS topic, with subscribed SQS queues. 
+ + :param exchange: The name of the exchange to add the notification to. + :param message: The message to be added. + :param routing_key: The routing key to use for the notification. + :param kwargs: Additional parameters + :return: None + """ + # Extract properties and message attributes from message for SNS parameters + request_params = {} + properties = message.get("properties", {}) + message_attrs = properties.get("message_attributes") + + # If the exchange is a FIFO topic, then add required MessageGroupId and + # MessageDeduplicationId attributes + if exchange.endswith(".fifo"): + request_params["MessageGroupId"] = properties.get( + "MessageGroupId", "default" + ) + request_params["MessageDeduplicationId"] = properties.get( + "MessageDeduplicationId", str(uuid.uuid4()) + ) + + # Add the message to the SNS topic + self.fanout.publish( + exchange_name=exchange, + message=dumps(message), + message_attributes=message_attrs, + request_params=request_params, + ) + + def _subscribe_queue_to_fanout_exchange_if_required(self, queue_name: str) -> None: + """Subscribe the given queue to the SNS topic if this is a fanout exchange. + + :param queue_name: The name of the queue to subscribe. + :return: None + """ + try: + # Get the exchange name for the queue and get the exchange type + exchange_name = self._get_exchange_for_queue(queue_name) + + # If the exchange is a fanout type and the transport supports it, + # subscribe the queue to the topic + if self._exchange_is_fanout(exchange_name): + self.fanout.subscriptions.subscribe_queue( + queue_name=queue_name, exchange_name=exchange_name + ) + + except UndefinedQueueException as e: + logger.debug( + f"Not subscribing queue '{queue_name}' to fanout exchange: {e}" + ) + + def _exchange_is_fanout(self, exchange_name: str) -> bool: + """Check if the given exchange is a fanout type. + + :param exchange_name: The name of the exchange to check. 
+ :return: True if the exchange is a fanout type and the transport supports it, + False otherwise. + """ + try: + exchange_type = self.state.exchanges[exchange_name]["type"] + return exchange_type == "fanout" and self.supports_fanout + except KeyError: + return False + + def _get_exchange_for_queue(self, queue_name: str) -> str: + """Get the exchange name for the given queue. + + :param queue_name: The name of the queue to get the exchange for. + :return: The name of the exchange for the given queue. + :raises UndefinedQueueException: If the queue has not been defined. + """ + try: + return list(self.state.queue_index[queue_name])[0].exchange + except (KeyError, IndexError): + raise UndefinedQueueException(f"Queue '{queue_name}' has not been defined.") + # ————————————————————————————————————————————————————————————— # _message_to_python helper methods (extracted for testing/readability) # ————————————————————————————————————————————————————————————— @@ -1076,7 +1305,7 @@ def _envelope_payload(self, payload, raw_text, message, q_url): # add SQS metadata di.update({ 'sqs_message': message, - 'sqs_queue': q_url, + 'sqs_queue': q_url, }) props['delivery_tag'] = message['ReceiptHandle'] @@ -1166,10 +1395,10 @@ class Transport(virtual.Transport): default_port = None connection_errors = ( virtual.Transport.connection_errors + - (exceptions.BotoCoreError, socket.error) + (aws_exceptions.BotoCoreError, socket.error) ) channel_errors = ( - virtual.Transport.channel_errors + (exceptions.BotoCoreError,) + virtual.Transport.channel_errors + (aws_exceptions.BotoCoreError,) ) driver_type = 'sqs' driver_name = 'sqs' diff --git a/kombu/transport/SQS/exceptions.py b/kombu/transport/SQS/exceptions.py new file mode 100644 index 0000000000..aa9a619c42 --- /dev/null +++ b/kombu/transport/SQS/exceptions.py @@ -0,0 +1,27 @@ +"""AWS SQS and SNS exceptions.""" + +from __future__ import annotations + + +class UndefinedQueueException(Exception): + """Predefined queues are being used and an 
undefined queue was used.""" + + +class UndefinedExchangeException(Exception): + """Predefined exchanges are being used and an undefined exchange/SNS topic was used.""" + + +class InvalidQueueException(Exception): + """Predefined queues are being used and configuration is not valid.""" + + +class AccessDeniedQueueException(Exception): + """Raised when access to the AWS queue is denied. + + This may occur if the permissions are not correctly set or the + credentials are invalid. + """ + + +class DoesNotExistQueueException(Exception): + """The specified queue doesn't exist.""" diff --git a/t/unit/transport/SQS/__init__.py b/t/unit/transport/SQS/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/t/unit/transport/SQS/conftest.py b/t/unit/transport/SQS/conftest.py new file mode 100644 index 0000000000..2b6d2351e6 --- /dev/null +++ b/t/unit/transport/SQS/conftest.py @@ -0,0 +1,151 @@ +"""Testing module for the kombu.transport.SQS package. + +NOTE: The SQSQueueMock and SQSConnectionMock classes originally come from +http://github.com/pcsforeducation/sqs-mock-python. They have been patched +slightly. 
+""" +from __future__ import annotations + +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from kombu import Connection +from kombu.transport.SQS.SNS import SNS, _SnsSubscription + +boto3 = pytest.importorskip('boto3') + + +from kombu.transport import SQS # noqa + +SQS_Channel_sqs = SQS.Channel.sqs + +example_predefined_queues = { + 'queue-1': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-1', + 'access_key_id': 'a', + 'secret_access_key': 'b', + 'backoff_tasks': ['svc.tasks.tasks.task1'], + 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640} + }, + 'queue-2': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-2', + 'access_key_id': 'c', + 'secret_access_key': 'd', + }, + "queue-3.fifo": { + "url": "https://sqs.us-east-1.amazonaws.com/xxx/queue-3.fifo", + "access_key_id": "e", + "secret_access_key": "f", + }, +} + +example_predefined_exchanges = { + "exchange-1": { + "arn": "arn:aws:sns:us-east-1:xxx:exchange-1", + "access_key_id": "a", + "secret_access_key": "b", + }, + "exchange-2.fifo": { + "arn": "arn:aws:sns:us-east-1:xxx:exchange-2", + "access_key_id": "a", + "secret_access_key": "b", + }, +} + + +@pytest.fixture +def connection_fixture(): + return Connection( + transport=SQS.Transport, + transport_options={ + "predefined_queues": example_predefined_queues, + }, + ) + + +@pytest.fixture +def channel_fixture(connection_fixture) -> SQS.Channel: + chan = connection_fixture.channel() + chan.region = "some-aws-region" + return chan + + +@pytest.fixture +def mock_sqs(): + with patch("kombu.transport.SQS.Channel.sqs") as mock: + mock.name = "Sqs client mock" + yield mock + + +@pytest.fixture +def mock_fanout(): + with patch("kombu.transport.SQS.Channel.fanout") as mock: + yield mock + + +class _BotoStsClientMock: + @staticmethod + def assume_role(RoleArn, RoleSessionName, DurationSeconds, *args, **kwargs): + return { + 'AssumedRoleUser': { + 'Arn': RoleArn, + 'AssumedRoleId': 'ARO123EXAMPLE123:Bob', + }, 
+ 'Credentials': { + 'AccessKeyId': 'AKIAIOSFODNN7EXAMPLE', + 'Expiration': datetime.datetime.now() + datetime.timedelta(seconds=DurationSeconds), + 'SecretAccessKey': 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY', + 'SessionToken': 'AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWs' + 'KWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGdQrmGdee' + 'hM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU9HFvlRd8Tx6q6' + 'fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yj' + 'c8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==', + }, + 'PackedPolicySize': 8, + 'ResponseMetadata': { + "key": "value" + }, + } + + +@pytest.fixture +def mock_boto_client(): + def _client_builder(service_name: str): + if service_name == "sts": + return _BotoStsClientMock() + else: + return MagicMock(name=f"Boto3 '{service_name}' mock") + + with patch("kombu.transport.SQS.boto3.client", wraps=_client_builder) as mock: + yield mock + + +@pytest.fixture +def mock_new_sqs_client(): + with patch("kombu.transport.SQS.Channel.new_sqs_client") as mock: + yield mock + + +@pytest.fixture +def sns_fanout(channel_fixture): + inst = SNS(channel_fixture) + + # Clear previous class vars + inst._predefined_clients = {} + inst._topic_arn_cache = {} + inst._exchange_topic_cache = {} + + return inst + + +@pytest.fixture +def sns_subscription(sns_fanout): + inst = _SnsSubscription(sns_fanout) + + # Clear previous class vars + inst._queue_arn_cache = {} + inst._subscription_arn_cache = {} + + return inst diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/SQS/test_SQS.py similarity index 59% rename from t/unit/transport/test_SQS.py rename to t/unit/transport/SQS/test_SQS.py index 4456de8445..8c8dd64393 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/SQS/test_SQS.py @@ -7,16 +7,20 @@ from __future__ import annotations import base64 +import logging import os import random import string from datetime import datetime, timedelta, 
timezone from queue import Empty -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch import pytest from kombu import Connection, Exchange, Queue, messaging +from kombu.transport.SQS import UndefinedQueueException, maybe_int + +from .conftest import example_predefined_exchanges, example_predefined_queues boto3 = pytest.importorskip('boto3') @@ -27,27 +31,6 @@ SQS_Channel_sqs = SQS.Channel.sqs -example_predefined_queues = { - 'queue-1': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-1', - 'access_key_id': 'a', - 'secret_access_key': 'b', - 'backoff_tasks': ['svc.tasks.tasks.task1'], - 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640} - }, - 'queue-2': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-2', - 'access_key_id': 'c', - 'secret_access_key': 'd', - }, - 'queue-3.fifo': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-3.fifo', - 'access_key_id': 'e', - 'secret_access_key': 'f', - } -} - - class SQSMessageMock: def __init__(self): """ @@ -163,6 +146,16 @@ def delete_queue(self, QueueUrl=None): del self._queues[queue_name] +class test_MaybeInt: + @pytest.mark.parametrize("input_value", [100, "100", 100.0]) + def test_working(self, input_value): + assert maybe_int(input_value) == 100 + + @pytest.mark.parametrize("input_value", [None, "hello", Mock(), "100.0"]) + def test_nan(self, input_value): + assert maybe_int(input_value) is input_value + + class test_Channel: def handleMessageCallback(self, message): @@ -240,12 +233,14 @@ def test_init(self): def test_region(self): _environ = dict(os.environ) - # when the region is unspecified - connection = Connection(transport=SQS.Transport) - channel = connection.channel() - assert channel.transport_options.get('region') is None - # the default region is us-east-1 - assert channel.region == 'us-east-1' + # when the region is unspecified, and Boto3 also does not have a region set + with patch("kombu.transport.SQS.boto3.Session") as boto_session_mock: 
+ boto_session_mock().region_name = None + connection = Connection(transport=SQS.Transport) + channel = connection.channel() + assert channel.transport_options.get("region") is None + # the default region is us-east-1 + assert channel.region == "us-east-1" # when boto3 picks a region os.environ['AWS_DEFAULT_REGION'] = 'us-east-2' @@ -284,20 +279,85 @@ def test_none_hostname_persists(self): def test_entity_name(self): assert self.channel.entity_name('foo') == 'foo' assert self.channel.entity_name('foo.bar-baz*qux_quux') == \ - 'foo-bar-baz_qux_quux' + 'foo-bar-baz_qux_quux' assert self.channel.entity_name('abcdef.fifo') == 'abcdef.fifo' + @patch('kombu.transport.virtual.base.Channel.basic_consume') + @patch('kombu.transport.SQS.Channel._noack_queues') + def test_basic_consume_no_ack(self, noack_queues_mock, basic_consume_mock): + # Arrange + channel = self.connection.channel() + + # Act + channel.basic_consume( + self.queue_name, + no_ack=True, + callback=self.handleMessageCallback, + consumer_tag='unittest' + ) + + # Assert + assert noack_queues_mock.add.call_args_list == [call(self.queue_name)] + assert basic_consume_mock.call_args_list == [ + call(self.queue_name, True, self.handleMessageCallback, 'unittest') + ] + + @patch('kombu.transport.virtual.base.Channel.basic_cancel') + @patch('kombu.transport.SQS.Channel._noack_queues') + def test_basic_cancel(self, noack_queues_mock, basic_consume_mock): + # Arrange + consumer_tag = 'unittest' + channel = self.connection.channel() + channel._consumers.add(consumer_tag) + channel._tag_to_queue[consumer_tag] = 'test-queue' + + # Act + channel.basic_cancel(consumer_tag) + + # Assert + assert noack_queues_mock.discard.call_args_list == [call("test-queue")] + assert basic_consume_mock.call_args_list == [call(consumer_tag)] + def test_resolve_queue_url(self): queue_name = 'unittest_queue' assert self.sqs_conn_mock._queues[queue_name].url == \ - self.channel._resolve_queue_url(queue_name) + 
self.channel._resolve_queue_url(queue_name) def test_new_queue(self): queue_name = 'new_unittest_queue' - self.channel._new_queue(queue_name) - assert queue_name in self.sqs_conn_mock._queues.keys() - # For cleanup purposes, delete the queue and the queue file - self.channel._delete(queue_name) + try: + self.channel._new_queue(queue_name) + assert queue_name in self.sqs_conn_mock._queues.keys() + finally: + # For cleanup purposes, delete the queue and the queue file + self.channel._delete(queue_name) + + def test_new_fifo_queue(self): + queue_name = 'new_unittest_queue.fifo' + try: + self.channel._new_queue(queue_name) + + queue: QueueMock = self.sqs_conn_mock._queues[queue_name] + assert isinstance(queue, QueueMock) + assert queue.url == 'https://sqs.us-east-1.amazonaws.com/xxx/' + queue_name + assert queue.creation_attributes == {'VisibilityTimeout': str(self.channel.visibility_timeout), + 'FifoQueue': 'true'} + + finally: + # For cleanup purposes, delete the queue and the queue file + self.channel._delete(queue_name) + + def test_create_queue_with_predefined_queues(self): + queue_name = 'new_unittest_queue' + try: + with pytest.raises( + UndefinedQueueException, + match=f"Queue with name '{queue_name}' must be defined in 'predefined_queues'." 
+ ): + self.channel.predefined_queues = example_predefined_queues + self.channel._create_queue(queue_name, {}) + finally: + self.channel.predefined_queues = set() def test_new_queue_custom_creation_attributes(self): self.connection.transport_options['sqs-creation-attributes'] = { @@ -370,6 +430,63 @@ def test_delete(self): assert queue_name not in self.channel._queue_cache assert queue_name not in self.sqs_conn_mock._queues + def test_delete_with_predefined_queues(self): + queue_name = 'new_unittest_queue' + try: + self.channel.predefined_queues = example_predefined_queues + assert self.channel._delete(queue_name) is None + finally: + self.channel.predefined_queues = set() + + def test_delete_with_no_exchange( + self, mock_fanout, channel_fixture, connection_fixture + ): + # Arrange + queue_name = "queue-1" + channel_fixture.supports_fanout = True + channel_fixture.predefined_queues = {} + self.sqs_conn_mock._queues[queue_name] = QueueMock( + url="https://sqs.us-east-1.amazonaws.com/xxx/queue-1" + ) + channel_fixture._new_queue(queue_name) + assert queue_name in channel_fixture._queue_cache + + # Act + channel_fixture._delete(queue_name) + + # Assert + assert mock_fanout.unsubscribe_queue.call_count == 0 + assert queue_name not in channel_fixture._queue_cache + assert queue_name not in self.sqs_conn_mock._queues + + def test_delete_with_fanout_exchange( + self, mock_fanout, channel_fixture, connection_fixture + ): + # Arrange + queue_name = "queue-1" + channel_fixture.supports_fanout = True + channel_fixture.predefined_queues = {} + self.sqs_conn_mock._queues[queue_name] = QueueMock( + url="https://sqs.us-east-1.amazonaws.com/xxx/queue-1" + ) + + # Declare fanout exchange and queue + exchange_name = "test_SQS_fanout" + exchange = Exchange(exchange_name, type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + assert queue_name in channel_fixture._queue_cache + + # Act + channel_fixture._delete(queue_name, exchange_name) + + # Assert 
+ assert mock_fanout.subscriptions.unsubscribe_queue.call_args_list == [ + call(queue_name="queue-1", exchange_name="test_SQS_fanout") + ] + assert queue_name not in channel_fixture._queue_cache + assert queue_name not in self.sqs_conn_mock._queues + def test_get_from_sqs(self): # Test getting a single message message = 'my test message' @@ -421,17 +538,17 @@ def test_get_bulk_raises_empty(self): # json/dict (encoded and raw) ( - b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' - b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}', - b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' - b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' + b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' + b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}', + b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' + b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' ), ( - base64.b64encode( + base64.b64encode( + b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' + b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}'), b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' - b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}'), - b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' - b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' + b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' ) ], @@ -478,23 +595,24 @@ def test_delete_message(self, queue_name, message, new_q_url, monkeypatch): [ # No 'properties' ( - {}, - "raw string", - {"ReceiptHandle": "RH"}, - "http://queue.url", - True, + {}, + "raw string", + {"ReceiptHandle": "RH"}, + "http://queue.url", + True, ), # Existing 
'properties' ( - {"properties": {"delivery_info": {"foo": "bar"}}}, - "ignored", - {"ReceiptHandle": "TAG"}, - "https://q.url", - False, + {"properties": {"delivery_info": {"foo": "bar"}}}, + "ignored", + {"ReceiptHandle": "TAG"}, + "https://q.url", + False, ), ], ) - def test_envelope_payload(self, initial_payload, raw_text, message, q_url, expect_body): + def test_envelope_payload(self, initial_payload, raw_text, message, q_url, + expect_body): payload = initial_payload.copy() result = self.channel._envelope_payload(payload, raw_text, message, q_url) @@ -530,14 +648,16 @@ def test_messages_to_python(self): # Get the messages now kombu_messages = [] for m in self.sqs_conn_mock.receive_message( - QueueUrl=q_url, - MaxNumberOfMessages=kombu_message_count)['Messages']: + QueueUrl=q_url, + MaxNumberOfMessages=kombu_message_count + )['Messages']: m['Body'] = Message(body=m['Body']).decode() kombu_messages.append(m) json_messages = [] for m in self.sqs_conn_mock.receive_message( - QueueUrl=q_url, - MaxNumberOfMessages=json_message_count)['Messages']: + QueueUrl=q_url, + MaxNumberOfMessages=json_message_count + )['Messages']: m['Body'] = Message(body=m['Body']).decode() json_messages.append(m) @@ -683,110 +803,185 @@ def test_get_async(self): 'WaitTimeSeconds': self.channel.wait_time_seconds, } assert get_list_args[3] == \ - self.channel.sqs().get_queue_url(self.queue_name).url + self.channel.sqs().get_queue_url(self.queue_name).url assert get_list_kwargs['parent'] == self.queue_name assert get_list_kwargs['protocol_params'] == { 'json': {'MessageSystemAttributeNames': ['ApproximateReceiveCount']}, 'query': {'MessageSystemAttributeName.1': 'ApproximateReceiveCount'}, } + @patch('kombu.transport.SQS.AsyncSQSConnection') + def test_asynsqs_with_predefined_queue_creates_queue_existing_client(self, mock_async_sqs): + # Arrange + queue_name = 'queue-1' + + mock_async_instance = Mock(name='async_sqs_instance') + mock_async_sqs.side_effect = AssertionError("This should not have 
been called") + + self.channel.predefined_queues = example_predefined_queues + self.channel._predefined_queue_async_clients[queue_name] = mock_async_instance + + # Act + result = self.channel.asynsqs(queue=queue_name) + + # Assert + assert result is mock_async_instance + assert mock_async_sqs.call_count == 0 + + @patch('kombu.transport.SQS.AsyncSQSConnection') + def test_asynsqs_with_predefined_queue_creates_queue_no_existing_client(self, mock_async_sqs): + # Arrange + queue_name = 'queue-1' + expected_queue_mock = self.channel.sqs(queue_name) + + mock_async_instance = Mock(name='async_sqs_instance') + mock_async_sqs.return_value = mock_async_instance + + self.channel.predefined_queues = example_predefined_queues + self.channel._predefined_queue_async_clients = {} + + # Act + result = self.channel.asynsqs(queue=queue_name) + + # Assert + assert result is mock_async_instance + assert mock_async_sqs.call_args_list == [ + call( + sqs_connection=expected_queue_mock, + region='us-east-1', + message_system_attribute_names=['ApproximateReceiveCount'], + message_attribute_names=[] + ) + ] + + @patch('kombu.transport.SQS.AsyncSQSConnection') + def test_asynsqs_with_defined_queues_but_missing(self, mock_async_sqs): + # Arrange + queue_name = 'queue-6' + + self.channel.predefined_queues = example_predefined_queues + + # Act + with pytest.raises( + UndefinedQueueException, + match="Queue with name 'queue-6' must be defined in 'predefined_queues'." 
+ ): + self.channel.asynsqs(queue=queue_name) + + # Assert + assert mock_async_sqs.call_count == 0 + @pytest.mark.parametrize('fetch_attributes,expected', [ # as a list for backwards compatibility ( - None, - {'message_system_attribute_names': ['ApproximateReceiveCount'], 'message_attribute_names': []} + None, + {'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': []} ), ( - 'incorrect_value', - {'message_system_attribute_names': ['ApproximateReceiveCount'], 'message_attribute_names': []} + 'incorrect_value', + {'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': []} ), ( - [], - {'message_system_attribute_names': ['ApproximateReceiveCount'], 'message_attribute_names': []} + [], + {'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': []} ), ( - ['ALL'], - {'message_system_attribute_names': ['ALL'], 'message_attribute_names': []} + ['ALL'], + {'message_system_attribute_names': ['ALL'], + 'message_attribute_names': []} ), ( - ['SenderId', 'SentTimestamp'], - { - 'message_system_attribute_names': ['SenderId', 'ApproximateReceiveCount', 'SentTimestamp'], - 'message_attribute_names': [] - } + ['SenderId', 'SentTimestamp'], + { + 'message_system_attribute_names': ['SenderId', + 'ApproximateReceiveCount', + 'SentTimestamp'], + 'message_attribute_names': [] + } ), # As a dict using only System Attributes ( - {'MessageSystemAttributeNames': ['All']}, - { - 'message_system_attribute_names': ['ALL'], - 'message_attribute_names': [] - } + {'MessageSystemAttributeNames': ['All']}, + { + 'message_system_attribute_names': ['ALL'], + 'message_attribute_names': [] + } ), ( - {'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp']}, - { - 'message_system_attribute_names': ['SenderId', 'ApproximateReceiveCount', 'SentTimestamp'], - 'message_attribute_names': [] - } + {'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp']}, + { + 
'message_system_attribute_names': ['SenderId', + 'ApproximateReceiveCount', + 'SentTimestamp'], + 'message_attribute_names': [] + } ), ( - {'MessageSystemAttributeNames_BAD_KEY': ['That', 'This']}, - { - 'message_system_attribute_names': ['ApproximateReceiveCount'], - 'message_attribute_names': [] - } + {'MessageSystemAttributeNames_BAD_KEY': ['That', 'This']}, + { + 'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': [] + } ), # As a dict using only Message Attributes ( - {'MessageAttributeNames': ['All']}, - { - 'message_system_attribute_names': ['ApproximateReceiveCount'], - 'message_attribute_names': ["ALL"] - } + {'MessageAttributeNames': ['All']}, + { + 'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': ["ALL"] + } ), ( - {'MessageAttributeNames': ['CustomProp', 'CustomProp2']}, - { - 'message_system_attribute_names': ['ApproximateReceiveCount'], - 'message_attribute_names': ['CustomProp', 'CustomProp2'] - } + {'MessageAttributeNames': ['CustomProp', 'CustomProp2']}, + { + 'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': ['CustomProp', 'CustomProp2'] + } ), ( - {'MessageAttributeNames_BAD_KEY': ['That', 'This']}, - { - 'message_system_attribute_names': ['ApproximateReceiveCount'], - 'message_attribute_names': [] - } + {'MessageAttributeNames_BAD_KEY': ['That', 'This']}, + { + 'message_system_attribute_names': ['ApproximateReceiveCount'], + 'message_attribute_names': [] + } ), # all together now... 
( - { - 'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp'], - 'MessageAttributeNames': ['CustomProp', 'CustomProp2']}, - { - 'message_system_attribute_names': ['SenderId', 'SentTimestamp', 'ApproximateReceiveCount'], - 'message_attribute_names': ['CustomProp', 'CustomProp2'] - } + { + 'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp'], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2']}, + { + 'message_system_attribute_names': ['SenderId', 'SentTimestamp', + 'ApproximateReceiveCount'], + 'message_attribute_names': ['CustomProp', 'CustomProp2'] + } ), ]) @pytest.mark.usefixtures('hub') def test_fetch_message_attributes(self, fetch_attributes, expected): - self.connection.transport_options['fetch_message_attributes'] = fetch_attributes # type: ignore + self.connection.transport_options[ + 'fetch_message_attributes'] = fetch_attributes # type: ignore async_sqs_conn = self.channel.asynsqs(self.queue_name) - assert async_sqs_conn.message_system_attribute_names == sorted(expected['message_system_attribute_names']) - assert async_sqs_conn.message_attribute_names == expected['message_attribute_names'] + assert async_sqs_conn.message_system_attribute_names == sorted( + expected['message_system_attribute_names']) + assert async_sqs_conn.message_attribute_names == expected[ + 'message_attribute_names'] @pytest.mark.usefixtures('hub') def test_fetch_message_attributes_does_not_exist(self): self.connection.transport_options = {} async_sqs_conn = self.channel.asynsqs(self.queue_name) - assert async_sqs_conn.message_system_attribute_names == ['ApproximateReceiveCount'] + assert async_sqs_conn.message_system_attribute_names == [ + 'ApproximateReceiveCount'] assert async_sqs_conn.message_attribute_names == [] def test_drain_events_with_empty_list(self): def mock_can_consume(): return False + self.channel.qos.can_consume = mock_can_consume with pytest.raises(Empty): self.channel.drain_events() @@ -805,6 +1000,7 @@ def 
test_drain_events_with_prefetch_5(self): def on_message_delivered(message, queue): current_delivery_tag[0] += 1 self.channel.qos.append(message, current_delivery_tag[0]) + self.channel.connection._deliver.side_effect = on_message_delivered # Now, generate all the messages @@ -839,6 +1035,7 @@ def test_drain_events_with_prefetch_none(self): def on_message_delivered(message, queue): current_delivery_tag[0] += 1 self.channel.qos.append(message, current_delivery_tag[0]) + self.channel.connection._deliver.side_effect = on_message_delivered # Now, generate all the messages @@ -1073,7 +1270,8 @@ def test_predefined_queues_backoff_policy(self): channel = connection.channel() def apply_backoff_policy( - queue_name, delivery_tag, retry_policy, backoff_tasks): + queue_name, delivery_tag, retry_policy, backoff_tasks + ): return None mock_apply_policy = Mock(side_effect=apply_backoff_policy) @@ -1100,14 +1298,6 @@ def test_predefined_queues_change_visibility_timeout(self): }) channel = connection.channel() - def extract_task_name_and_number_of_retries(delivery_tag): - return 'svc.tasks.tasks.task1', 2 - - mock_extract_task_name_and_number_of_retries = Mock( - side_effect=extract_task_name_and_number_of_retries) - channel.qos.extract_task_name_and_number_of_retries = \ - mock_extract_task_name_and_number_of_retries - queue_name = "queue-1" exchange = Exchange('test_SQS', type='direct') @@ -1116,6 +1306,8 @@ def extract_task_name_and_number_of_retries(delivery_tag): message_mock = Mock() message_mock.delivery_info = {'routing_key': queue_name} + message_mock.headers = {"task": "svc.tasks.tasks.task1"} + message_mock.properties = {"delivery_info": {"sqs_message": {"Attributes": {"ApproximateReceiveCount": 2}}}} channel.qos._delivered['test_message_id'] = message_mock channel.sqs = Mock() @@ -1149,7 +1341,7 @@ def test_predefined_queues_put_to_fifo_queue(self): sqs_queue_mock.send_message.assert_called_once() assert 'MessageGroupId' in sqs_queue_mock.send_message.call_args[1] 
assert 'MessageDeduplicationId' in \ - sqs_queue_mock.send_message.call_args[1] + sqs_queue_mock.send_message.call_args[1] def test_predefined_queues_put_to_queue(self): connection = Connection(transport=SQS.Transport, transport_options={ @@ -1176,20 +1368,20 @@ def test_predefined_queues_put_to_queue(self): assert sqs_queue_mock.send_message.call_args[1]['DelaySeconds'] == 10 @pytest.mark.parametrize('predefined_queues', ( - { - 'invalid-fifo-queue-name': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue.fifo', - 'access_key_id': 'a', - 'secret_access_key': 'b' - } - }, - { - 'standard-queue.fifo': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue', - 'access_key_id': 'a', - 'secret_access_key': 'b' + { + 'invalid-fifo-queue-name': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue.fifo', + 'access_key_id': 'a', + 'secret_access_key': 'b' + } + }, + { + 'standard-queue.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue', + 'access_key_id': 'a', + 'secret_access_key': 'b' + } } - } )) def test_predefined_queues_invalid_configuration(self, predefined_queues): connection = Connection(transport=SQS.Transport, transport_options={ @@ -1245,7 +1437,8 @@ def test_sts_new_session_with_buffer_time(self): mock_new_sqs_client = Mock() channel.new_sqs_client = mock_new_sqs_client - expiration_time = datetime.now(timezone.utc) + timedelta(seconds=sts_token_timeout) + expiration_time = datetime.now(timezone.utc) + timedelta( + seconds=sts_token_timeout) mock_generate_sts_session_token.side_effect = [ { @@ -1262,7 +1455,8 @@ def test_sts_new_session_with_buffer_time(self): # Assert mock_generate_sts_session_token.assert_called_once() - assert channel.sts_expiration == expiration_time - timedelta(seconds=sts_token_buffer_time) + assert channel.sts_expiration == expiration_time - timedelta( + seconds=sts_token_buffer_time) def test_sts_session_expired(self): # Arrange @@ -1313,7 +1507,8 @@ def test_sts_session_expired_with_buffer_time(self): 
mock_new_sqs_client = Mock() channel.new_sqs_client = mock_new_sqs_client - expiration_time = datetime.now(timezone.utc) + timedelta(seconds=sts_token_timeout) + expiration_time = datetime.now(timezone.utc) + timedelta( + seconds=sts_token_timeout) mock_generate_sts_session_token.side_effect = [ { @@ -1330,7 +1525,8 @@ def test_sts_session_expired_with_buffer_time(self): # Assert mock_generate_sts_session_token.assert_called_once() - assert channel.sts_expiration == expiration_time - timedelta(seconds=sts_token_buffer_time) + assert channel.sts_expiration == expiration_time - timedelta( + seconds=sts_token_buffer_time) def test_sts_session_not_expired(self): # Arrange @@ -1383,12 +1579,15 @@ def test_sts_session_with_multiple_predefined_queues(self): channel.generate_sts_session_token = mock_generate_sts_session_token # Act - sqs(queue='queue-1') - sqs(queue='queue-2') + sqs(queue="queue-1") + sqs(queue="queue-2") + + # Call queue a second time to check new STS token is not generated + sqs(queue="queue-2") # Assert - mock_generate_sts_session_token.assert_called() - mock_new_sqs_client.assert_called() + assert mock_generate_sts_session_token.call_count == 2 + assert mock_new_sqs_client.call_count == 2 def test_message_attribute(self): message = 'my test message' @@ -1396,8 +1595,492 @@ def test_message_attribute(self): 'Attribute1': {'DataType': 'String', 'StringValue': 'STRING_VALUE'} } - ) + ) output_message = self.queue(self.channel).get() assert message == output_message.payload # It's not propagated to the properties - assert 'message_attributes' not in output_message.properties + assert "message_attributes" not in output_message.properties + + def test_exchange_is_fanout_no_defined_queues(self, channel_fixture): + # Act & assert + assert channel_fixture._exchange_is_fanout("queue-1") is False + + def test_exchange_is_fanout_with_fanout_exchange( + self, channel_fixture: SQS.Channel, mock_fanout + ): + # Arrange + channel_fixture.supports_fanout = True + + # 
One fanout exchange and queue + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + # One direct exchange and queue + exchange = Exchange("test_SQS", type="direct") + queue = Queue("queue-2", exchange) + queue(channel_fixture).declare() + + # Act & Assert + assert channel_fixture._exchange_is_fanout("test_SQS_fanout") is True + assert channel_fixture._exchange_is_fanout("test_SQS") is False + + def test_get_exchange_for_queue_with_defined_queue( + self, channel_fixture: SQS.Channel, mock_fanout + ): + # Arrange + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + # Act + result = channel_fixture._get_exchange_for_queue("queue-1") + + # Assert + assert result == "test_SQS_fanout" + + def test_get_exchange_for_queue_with_queue_not_defined( + self, channel_fixture: SQS.Channel, mock_fanout + ): + # Arrange + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + # Act + with pytest.raises( + UndefinedQueueException, match="Queue 'queue-2' has not been defined." 
+ ): + channel_fixture._get_exchange_for_queue("queue-2") + + def test_remove_stale_sns_subscriptions_no_defined_queues( + self, mock_fanout, channel_fixture + ): + # Arrange + mock_fanout.subscriptions.cleanup.return_value = "This should not be returned" + + # Act + result = channel_fixture.remove_stale_sns_subscriptions("queue-1") + + # Assert + assert result is None + assert mock_fanout.subscriptions.cleanup.call_count == 0 + + def test_remove_stale_sns_subscriptions_with_fanout_exchange( + self, mock_fanout, channel_fixture: SQS.Channel + ): + # Arrange + mock_fanout.subscriptions.cleanup.return_value = None + channel_fixture.supports_fanout = True + + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + # Act + result = channel_fixture.remove_stale_sns_subscriptions("test_SQS_fanout") + + # Assert + assert result is None + assert mock_fanout.subscriptions.cleanup.call_count == 1 + + def test_subscribe_queue_to_fanout_exchange_if_required_with_fanout( + self, mock_fanout, channel_fixture: SQS.Channel + ): + # Arrange + mock_fanout.subscriptions.subscribe_queue.return_value = None + channel_fixture.supports_fanout = True + + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + # Act + result = channel_fixture._subscribe_queue_to_fanout_exchange_if_required( + "queue-1" + ) + + # Assert + assert result is None + assert mock_fanout.subscriptions.subscribe_queue.call_args_list == [ + call(queue_name="queue-1", exchange_name="test_SQS_fanout") + ] + + def test_subscribe_queue_to_fanout_exchange_if_required_without_fanout( + self, mock_fanout, channel_fixture: SQS.Channel + ): + # Arrange + mock_fanout.cleanup.return_value = None + channel_fixture.supports_fanout = True + + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + exchange = 
Exchange("test_SQS", type="direct") + queue = Queue("queue-2", exchange) + queue(channel_fixture).declare() + + # Act + result = channel_fixture._subscribe_queue_to_fanout_exchange_if_required( + "queue-2" + ) + + # Assert + assert result is None + assert mock_fanout.subscribe_queue.call_count == 0 + + def test_subscribe_queue_to_fanout_exchange_if_required_not_defined( + self, mock_fanout, channel_fixture: SQS.Channel, caplog + ): + # Arrange + caplog.set_level(logging.DEBUG) + mock_fanout.cleanup.return_value = None + channel_fixture.supports_fanout = True + + exchange = Exchange("test_SQS_fanout", type="fanout") + queue = Queue("queue-1", exchange) + queue(channel_fixture).declare() + + # Act + result = channel_fixture._subscribe_queue_to_fanout_exchange_if_required( + "queue-2" + ) + + # Assert + assert result is None + assert mock_fanout.subscribe_queue.call_count == 0 + assert ( + "Not subscribing queue 'queue-2' to fanout exchange: Queue 'queue-2' has" + " not been defined." + ) in caplog.text + + @patch( + "kombu.transport.SQS.uuid.uuid4", + return_value="70c8cdfc-9bec-4d20-bbe1-3c155b794467", + ) + def test_put_fanout_fifo_queue( + self, _uuid_mock, mock_fanout, channel_fixture: SQS.Channel + ): + # Arrange + message = {"key1": "This is a value", "key2": 123, "key3": True} + + # Act + channel_fixture._put_fanout("queue-1.fifo", message, "") + + # Assert + assert mock_fanout.publish.call_args_list == [ + call( + exchange_name="queue-1.fifo", + message='{"key1": "This is a value", "key2": 123, "key3": true}', + message_attributes=None, + request_params={ + "MessageGroupId": "default", + "MessageDeduplicationId": "70c8cdfc-9bec-4d20-bbe1-3c155b794467", + }, + ) + ] + + def test_put_fanout_fifo_queue_custom_msg_groups( + self, mock_fanout, channel_fixture: SQS.Channel + ): + # Arrange + message = { + "key1": "This is a value", + "key2": 123, + "key3": True, + "properties": { + "MessageGroupId": "ThisIsNotDefault", + "MessageDeduplicationId": "MyDedupId", + }, 
+ } + + # Act + channel_fixture._put_fanout("queue-1.fifo", message, "") + + # Assert + assert mock_fanout.publish.call_args_list == [ + call( + exchange_name="queue-1.fifo", + message=( + '{"key1": "This is a value", "key2": 123, "key3": true,' + ' "properties": {"MessageGroupId": "ThisIsNotDefault", ' + '"MessageDeduplicationId": "MyDedupId"}}' + ), + message_attributes=None, + request_params={ + "MessageGroupId": "ThisIsNotDefault", + "MessageDeduplicationId": "MyDedupId", + }, + ) + ] + + def test_put_fanout_non_fifo_queue(self, mock_fanout, channel_fixture: SQS.Channel): + # Arrange + message = {"key1": "This is a value", "key2": 123, "key3": True} + + # Act + channel_fixture._put_fanout("queue-1", message, "") + + # Assert + assert mock_fanout.publish.call_args_list == [ + call( + exchange_name="queue-1", + message='{"key1": "This is a value", "key2": 123, "key3": true}', + message_attributes=None, + request_params={}, + ) + ] + + def test_put_fanout_with_msg_attrs(self, mock_fanout, channel_fixture: SQS.Channel): + # Arrange + message = { + "key1": "This is a value", + "key2": 123, + "key3": True, + "properties": { + "message_attributes": {"attr1": "my-attribute-value", "attr2": 123}, + }, + } + + # Act + channel_fixture._put_fanout("queue-1", message, "") + + # Assert + assert mock_fanout.publish.call_args_list == [ + call( + exchange_name="queue-1", + message=( + '{"key1": "This is a value", "key2": 123, "key3": true, ' + '"properties": {"message_attributes": {"attr1": ' + '"my-attribute-value", "attr2": 123}}}' + ), + message_attributes={"attr1": "my-attribute-value", "attr2": 123}, + request_params={}, + ) + ] + + @pytest.mark.parametrize( + "sf_transport_value, expected_result", + [(True, True), (False, False), (None, False)], + ) + def test_supports_fanout( + self, sf_transport_value, expected_result, channel_fixture + ): + # Arrange + if sf_transport_value is not None: + channel_fixture.transport_options["supports_fanout"] = sf_transport_value + + # 
Act & Assert + assert channel_fixture.supports_fanout == expected_result + + def test_sqs_client_already_initialised(self, channel_fixture, mock_new_sqs_client): + # Arrange + sqs_client_mock = Mock(name="My SQS client") + channel_fixture._sqs = sqs_client_mock + + # Act + result = SQS_Channel_sqs.__get__(channel_fixture, SQS.Channel)() + + # Assert + assert result is sqs_client_mock + assert mock_new_sqs_client.call_count == 0 + + def test_sqs_client_predefined_queue_not_defined( + self, channel_fixture, mock_new_sqs_client + ): + # Arrange + channel_fixture._sqs = None + + # Act + with pytest.raises( + UndefinedQueueException, + match="Queue with name 'queue-4' must be defined in 'predefined_queues'.", + ): + SQS_Channel_sqs.__get__(channel_fixture, SQS.Channel)(queue="queue-4") + + # Assert + assert mock_new_sqs_client.call_count == 0 + + def test_sqs_client_predefined_queue_already_has_client( + self, channel_fixture, mock_new_sqs_client + ): + # Arrange + mock_client = Mock(name="My SQS client") + channel_fixture._sqs = None + channel_fixture._predefined_queue_clients["queue-1"] = mock_client + + # Act + result = SQS_Channel_sqs.__get__(channel_fixture, SQS.Channel)(queue="queue-1") + + # Assert + assert result == mock_client + assert mock_new_sqs_client.call_count == 0 + + def test_sqs_client_predefined_queue_does_not_have_client( + self, channel_fixture, mock_new_sqs_client + ): + # Arrange + queue_2_client = Mock(name="My new SQS client") + queue_1_client = Mock(name="A different SQS client") + channel_fixture._sqs = None + channel_fixture._predefined_queue_clients = {"queue-1": queue_1_client} + mock_new_sqs_client.return_value = queue_2_client + + # Act + result = SQS_Channel_sqs.__get__(channel_fixture, SQS.Channel)(queue="queue-2") + + # Assert + assert channel_fixture._predefined_queue_clients == { + "queue-1": queue_1_client, + "queue-2": queue_2_client, + } + assert result == queue_2_client + assert mock_new_sqs_client.call_args_list == [ + 
call(region="some-aws-region", access_key_id="c", secret_access_key="d") + ] + + def test_fanout_instance_already_initialised(self, channel_fixture): + # Arrange + sns_fanout_mock = Mock(name="SNS Fanout Class") + channel_fixture._fanout = sns_fanout_mock + + # Act + result = channel_fixture.fanout + + # Assert + assert result is sns_fanout_mock + + def test_fanout_client_not_initialised(self, channel_fixture): + with patch("kombu.transport.SQS.SNS") as fan_mock: + # Arrange + channel_fixture._fanout = None + + # Act + result = channel_fixture.fanout + + # Assert + assert fan_mock.call_args_list == [call(channel_fixture)] + assert result == fan_mock() + + @pytest.mark.parametrize("exchanges", [None, example_predefined_exchanges]) + def test_predefined_exchanges(self, exchanges, channel_fixture): + # Arrange + if exchanges is None: + channel_fixture.transport_options.pop("predefined_exchanges", None) + else: + channel_fixture.transport_options["predefined_exchanges"] = exchanges + + # Act & Assert + expected_result = exchanges if exchanges is not None else {} + assert channel_fixture.predefined_exchanges == expected_result + + @pytest.mark.parametrize('fetch_attributes, call_args', [ + ( + { + 'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp'], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount', 'SenderId', 'SentTimestamp'], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + ), + ( + { + 'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp'], + 'MessageAttributeNames': [] + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount', 'SenderId', 'SentTimestamp'], + }, + ), + ( + { + 'MessageSystemAttributeNames': ['SenderId', 'SentTimestamp'], + 'MessageAttributeNames': None + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount', 'SenderId', 'SentTimestamp'], + }, + ), + ( + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount', 
'SenderId', 'SentTimestamp'], + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount', 'SenderId', 'SentTimestamp'], + }, + ), + ( + { + 'MessageSystemAttributeNames': [], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount'], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + ), + ( + { + 'MessageSystemAttributeNames': None, + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount'], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + ), + ( + { + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + { + 'MessageSystemAttributeNames': ['ApproximateReceiveCount'], + 'MessageAttributeNames': ['CustomProp', 'CustomProp2'] + }, + ) + ]) + def test_receive_message(self, fetch_attributes, call_args): + # Arrange + self.connection.transport_options["fetch_message_attributes"] = fetch_attributes + + with patch.object(self.sqs_conn_mock, 'receive_message', wraps=self.sqs_conn_mock.receive_message, + name="RxSpy") as rx_spy: + self.sqs_conn_mock.send_message(QueueUrl=f'https://sqs.us-east-1.amazonaws.com/xxx/{self.queue_name}', + MessageBody='This is a test') + + # Act + result = self.channel._receive_message(self.queue_name) + + # Assert + assert 1 == len(result["Messages"]) + assert 'This is a test' == result["Messages"][0]['Body'] + assert rx_spy.call_args_list == [ + call( + QueueUrl=f'https://sqs.us-east-1.amazonaws.com/xxx/{self.queue_name}', + MaxNumberOfMessages=1, WaitTimeSeconds=10, + **call_args + )] + + @pytest.mark.freeze_time("2025-11-05 14:57:12") + def test_generate_sts_session_token(self, mock_boto_client): + # Arrange + role_arn = "arn:aws:iam::123456789012:role/role-a" + token_expiry = 123 + + # Act + session_token = self.channel.generate_sts_session_token(role_arn, token_expiry) + + # Assert + assert session_token == { + 'AccessKeyId': 'AKIAIOSFODNN7EXAMPLE', + 
'Expiration': datetime(2025, 11, 5, 14, 57, 12) + timedelta(seconds=token_expiry), + 'SecretAccessKey': 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY', + 'SessionToken': 'AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWsKWHGBuFq' + 'wAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGdQrmGdeehM4IC1NtBmUpp2wU' + 'E8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo' + '0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSI' + 'lTJabIQwj2ICCR/oLxBA==', + } diff --git a/t/unit/transport/SQS/test_SQS_SNS.py b/t/unit/transport/SQS/test_SQS_SNS.py new file mode 100644 index 0000000000..3e4faae780 --- /dev/null +++ b/t/unit/transport/SQS/test_SQS_SNS.py @@ -0,0 +1,1245 @@ +"""Testing module for the kombu.transport.SQS package. + +NOTE: The SQSQueueMock and SQSConnectionMock classes originally come from +http://github.com/pcsforeducation/sqs-mock-python. They have been patched +slightly. +""" +from __future__ import annotations + +import json +import logging +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, call, patch + +import pytest + +from kombu import Exchange, Queue +from kombu.exceptions import KombuError +from kombu.transport.SQS.exceptions import UndefinedExchangeException + +boto3 = pytest.importorskip('boto3') + +from botocore.exceptions import ClientError # noqa + +from kombu.transport import SQS # noqa + +SQS_Channel_sqs = SQS.Channel.sqs + +example_predefined_queues = { + 'queue-1': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-1', + 'access_key_id': 'a', + 'secret_access_key': 'b', + 'backoff_tasks': ['svc.tasks.tasks.task1'], + 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640} + }, + 'queue-2': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-2', + 'access_key_id': 'c', + 'secret_access_key': 'd', + }, + "queue-3.fifo": { + "url": "https://sqs.us-east-1.amazonaws.com/xxx/queue-3.fifo", + "access_key_id": "e", + 
"secret_access_key": "f", + }, +} + + +class test_SNS: + @pytest.fixture + def mock_sts_credentials(self): + return { + "AccessKeyId": "test_access_key", + "SecretAccessKey": "test_secret_key", + "SessionToken": "test_session_token", + "Expiration": datetime.now(timezone.utc) + timedelta(hours=1), + } + + @pytest.mark.parametrize("exchange_name", ["test_exchange"]) + def test_initialise_exchange_with_existing_topic(self, sns_fanout, exchange_name): + # Arrange + sns_fanout._topic_arn_cache[exchange_name] = "existing_arn" + sns_fanout.subscriptions = Mock() + + # Act + result = sns_fanout.initialise_exchange(exchange_name) + + # Assert + assert result is None + assert sns_fanout.subscriptions.cleanup.call_args_list == [call(exchange_name)] + assert sns_fanout._topic_arn_cache[exchange_name] == "existing_arn" + + def test_initialise_exchange_with_predefined_exchanges(self, sns_fanout, caplog): + # Arrange + exchange_name = "test_exchange" + + sns_fanout.channel.predefined_exchanges = {"exchange-1": {}} + sns_fanout.subscriptions = Mock() + caplog.set_level(logging.DEBUG) + + # Act + result = sns_fanout.initialise_exchange(exchange_name) + + # Assert + assert result is None + assert sns_fanout.subscriptions.cleanup.call_args_list == [call(exchange_name)] + assert ( + "'predefined_exchanges' has been specified, so SNS topics will not be created." 
+ in caplog.text + ) + + def test_initialise_exchange_create_new_topic(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + + sns_fanout.channel.predefined_exchanges = False + sns_fanout.subscriptions = Mock() + sns_fanout._create_sns_topic = Mock(return_value="new_arn") + + # Act + result = sns_fanout.initialise_exchange(exchange_name) + + # Assert + assert result is None + assert sns_fanout.subscriptions.cleanup.call_args_list == [call(exchange_name)] + assert sns_fanout._create_sns_topic.call_args_list == [call(exchange_name)] + assert sns_fanout._topic_arn_cache[exchange_name] == "new_arn" + + def test_get_topic_arn_create_new_topic(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + + sns_fanout.channel.predefined_exchanges = {} + sns_fanout._create_sns_topic = Mock(return_value="new_arn") + + # Act + result = sns_fanout._get_topic_arn(exchange_name) + + # Assert + assert result == "new_arn" + assert sns_fanout._create_sns_topic.call_args_list == [call(exchange_name)] + assert sns_fanout._topic_arn_cache[exchange_name] == "new_arn" + + def test_get_topic_arn_predefined_exchange_found(self, sns_fanout): + # Arrange + exchange_name = "exchange-1" + + sns_fanout.channel.predefined_exchanges = { + "exchange-1": {"arn": "some-existing-arn"} + } + sns_fanout._create_sns_topic = Mock(return_value="new_arn") + + # Act + result = sns_fanout._get_topic_arn(exchange_name) + + # Assert + assert result == "some-existing-arn" + assert sns_fanout._create_sns_topic.call_count == 0 + assert sns_fanout._topic_arn_cache[exchange_name] == "some-existing-arn" + + def test_get_topic_arn_predefined_exchange_not_found(self, sns_fanout): + # Arrange + exchange_name = "exchange-2" + + sns_fanout.channel.predefined_exchanges = { + "exchange-1": {"arn": "some-existing-arn"} + } + sns_fanout._create_sns_topic = Mock(return_value="new_arn") + + # Act + with pytest.raises( + UndefinedExchangeException, + match="Exchange with name 'exchange-2' must be defined in 
'predefined_exchanges'.", + ): + sns_fanout._get_topic_arn(exchange_name) + + # Assert + assert sns_fanout._create_sns_topic.call_count == 0 + assert sns_fanout._topic_arn_cache.get(exchange_name) is None + + def test_publish_successful(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + message = "test_message" + + sns_fanout._topic_arn_cache[exchange_name] = "existing_arn" + mock_client = Mock() + mock_client.publish.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} + sns_fanout.get_client = Mock(return_value=mock_client) + + # Act + sns_fanout.publish(exchange_name, message) + + # Assert + assert sns_fanout.get_client.call_args_list == [call(exchange_name)] + assert mock_client.publish.call_args_list == [ + call(TopicArn="existing_arn", Message="test_message") + ] + + def test_publish_with_attributes_and_params(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + message = "test_message" + message_attributes = {"attr1": "value1", "attr2": 123, "A boolean?": True} + request_params = {"param1": "value1"} + + sns_fanout._topic_arn_cache[exchange_name] = "existing_arn" + mock_client = Mock() + mock_client.publish.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} + sns_fanout.get_client = Mock(return_value=mock_client) + + # Act + sns_fanout.publish(exchange_name, message, message_attributes, request_params) + + # Assert + assert sns_fanout.get_client.call_args_list == [((exchange_name,), {})] + assert mock_client.publish.call_args_list == [ + call( + TopicArn="existing_arn", + Message="test_message", + param1="value1", + MessageAttributes={ + "attr1": {"DataType": "String", "StringValue": "value1"}, + "attr2": {"DataType": "String", "StringValue": "123"}, + "A boolean?": {"DataType": "String", "StringValue": "True"}, + }, + ) + ] + + def test_publish_failure(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + message = "test_message" + + sns_fanout._topic_arn_cache[exchange_name] = "existing_arn" + + 
mock_client = Mock() + mock_client.publish.return_value = {"ResponseMetadata": {"HTTPStatusCode": 400}} + sns_fanout.get_client = Mock(return_value=mock_client) + + # Act and Assert + with pytest.raises( + UndefinedExchangeException, + match="Unable to send message to topic 'existing_arn': status code was 400", + ): + sns_fanout.publish("test_exchange", message) + + assert sns_fanout.get_client.call_args_list == [call(exchange_name)] + assert mock_client.publish.call_args_list == [ + call(TopicArn="existing_arn", Message="test_message") + ] + + def test_create_sns_topic_success(self, sns_fanout, caplog): + # Arrange + caplog.set_level(logging.DEBUG) + sns_fanout.get_client = Mock() + mock_client = sns_fanout.get_client.return_value + mock_client.create_topic.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "TopicArn": "arn:aws:sns:us-east-1:123456789012:my-new-topic", + } + + # Act + result = sns_fanout._create_sns_topic("my-new-topic") + + # Assert + assert result == "arn:aws:sns:us-east-1:123456789012:my-new-topic" + assert mock_client.create_topic.call_args_list == [ + call( + Name="my-new-topic", + Attributes={"FifoTopic": "False"}, + Tags=[ + {"Key": "ManagedBy", "Value": "Celery/Kombu"}, + { + "Key": "Description", + "Value": "This SNS topic is used by Kombu to enable Fanout support for AWS SQS.", + }, + ], + ) + ] + assert "Creating SNS topic 'my-new-topic'" in caplog.text + assert ( + "Created SNS topic 'my-new-topic' with ARN 'arn:aws:sns:us-east-1:123456789012:my-new-topic'" + in caplog.text + ) + + def test_create_sns_topic_failure(self, sns_fanout): + # Arrange + sns_fanout.get_client = Mock() + mock_client = sns_fanout.get_client.return_value + mock_client.create_topic.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 400} + } + + # Act and Assert + with pytest.raises( + UndefinedExchangeException, match="Unable to create SNS topic" + ): + sns_fanout._create_sns_topic("test_exchange") + + def test_create_sns_topic_fifo(self, 
sns_fanout, caplog): + # Arrange + caplog.set_level(logging.DEBUG) + sns_fanout.get_client = Mock() + mock_client = sns_fanout.get_client.return_value + mock_client.create_topic.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test_topic.fifo", + } + + # Act + result = sns_fanout._create_sns_topic("test_topic.fifo") + + # Assert + assert result == "arn:aws:sns:us-east-1:123456789012:test_topic.fifo" + assert mock_client.create_topic.call_args_list == [ + call( + Name="test_topic.fifo", + Attributes={"FifoTopic": "True"}, + Tags=[ + {"Key": "ManagedBy", "Value": "Celery/Kombu"}, + { + "Key": "Description", + "Value": "This SNS topic is used by Kombu to enable Fanout support for AWS SQS.", + }, + ], + ) + ] + assert "Creating SNS topic 'test_topic.fifo'" in caplog.text + assert ( + "Created SNS topic 'test_topic.fifo' with ARN 'arn:aws:sns:us-east-1:123456789012:test_topic.fifo'" + in caplog.text + ) + + def test_get_client_predefined_exchange(self, sns_fanout): + # Arrange + sns_fanout.channel.predefined_exchanges = { + "test_exchange": {"region": "us-west-2"} + } + sns_fanout._create_boto_client = Mock() + + # Act + result = sns_fanout.get_client("test_exchange") + + # Assert + assert result == sns_fanout._create_boto_client.return_value + assert sns_fanout._create_boto_client.call_args_list == [ + call(region="us-west-2", access_key_id=None, secret_access_key=None) + ] + + def test_get_client_undefined_exchange(self, sns_fanout): + # Arrange + sns_fanout.channel.predefined_exchanges = {"exchange-1": {}} + + # Act & Assert + with pytest.raises( + UndefinedExchangeException, + match="Exchange with name 'test_exchange' must be defined in 'predefined_exchanges'.", + ): + sns_fanout.get_client("test_exchange") + + def test_get_client_sts_session(self, sns_fanout): + # Arrange + sns_fanout.channel.predefined_exchanges = { + "test_exchange": { + "arn": "test_arn", + } + } + 
sns_fanout.channel.connection.client.transport_options = { + "sts_role_arn": "test_arn" + } + sns_fanout._handle_sts_session = Mock() + + # Act + result = sns_fanout.get_client("test_exchange") + + # Assert + assert result == sns_fanout._handle_sts_session.return_value + assert sns_fanout._handle_sts_session.call_args_list == [ + call("test_exchange", {"arn": "test_arn"}) + ] + + def test_get_client_existing_predefined_client(self, sns_fanout): + # Arrange + sns_fanout.channel.predefined_exchanges = { + "test_exchange": { + "arn": "test_arn", + } + } + client_mock = Mock() + sns_fanout._predefined_clients = {"test_exchange": client_mock} + + # Act + result = sns_fanout.get_client("test_exchange") + + # Assert + assert result is client_mock + + def test_get_client_existing_client(self, sns_fanout): + # Arrange + sns_fanout._client = Mock() + + # Act + result = sns_fanout.get_client() + + # Assert + assert result == sns_fanout._client + + def test_get_client_new_client(self, sns_fanout): + # Arrange + sns_fanout._create_boto_client = Mock() + sns_fanout.channel.conninfo.userid = "MyAccessKeyID" + sns_fanout.channel.conninfo.password = "MySecretAccessKey" + + # Act + result = sns_fanout.get_client() + + # Assert + assert result == sns_fanout._create_boto_client.return_value + assert ( + sns_fanout._create_boto_client.call_args_list + == [ + call( + region="some-aws-region", + access_key_id="MyAccessKeyID", + secret_access_key="MySecretAccessKey", + ) + ] + ) + + def test_token_refresh_required_no_date(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + exchange_config = {"region": "us-west-2"} + + create_session_mock = Mock() + sns_fanout._create_boto_client_with_sts_session = create_session_mock + + # Act + result = sns_fanout._handle_sts_session(exchange_name, exchange_config) + + # Assert + assert result == sns_fanout._create_boto_client_with_sts_session.return_value + assert create_session_mock.call_args_list == [ + call("test_exchange", 
region="us-west-2") + ] + + def test_token_refresh_required_expired_date(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + exchange_config = {"region": "us-west-2"} + + create_session_mock = Mock() + sns_fanout._create_boto_client_with_sts_session = create_session_mock + sns_fanout.sts_expiration = datetime.now(timezone.utc) - timedelta(minutes=1) + + client_mock = Mock() + sns_fanout._predefined_clients = {"test_exchange": client_mock} + + # Act + result = sns_fanout._handle_sts_session(exchange_name, exchange_config) + + # Assert + assert result == sns_fanout._create_boto_client_with_sts_session.return_value + assert create_session_mock.call_args_list == [ + call("test_exchange", region="us-west-2") + ] + + def test_token_refresh_required_non_expired_date_without_client(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + exchange_config = {"region": "us-west-2"} + + create_session_mock = Mock() + sns_fanout._create_boto_client_with_sts_session = create_session_mock + sns_fanout.sts_expiration = datetime.now(timezone.utc) + timedelta(minutes=1) + client_mock = Mock() + + sns_fanout._predefined_clients = {"another-exchange": client_mock} + + # Act + result = sns_fanout._handle_sts_session(exchange_name, exchange_config) + + # Assert + assert result == sns_fanout._create_boto_client_with_sts_session.return_value + assert create_session_mock.call_args_list == [ + call("test_exchange", region="us-west-2") + ] + + def test_token_refresh_required_non_expired_date_with_client(self, sns_fanout): + # Arrange + exchange_name = "test_exchange" + exchange_config = {"region": "us-west-2"} + + create_session_mock = Mock() + sns_fanout._create_boto_client_with_sts_session = create_session_mock + sns_fanout.sts_expiration = datetime.now(timezone.utc) + timedelta(minutes=1) + client_mock = Mock() + + sns_fanout._predefined_clients = {exchange_name: client_mock} + + # Act + result = sns_fanout._handle_sts_session(exchange_name, exchange_config) + + # 
Assert + assert create_session_mock.call_count == 0 + assert result is client_mock + + def test_create_boto_client_with_sts_session( + self, sns_fanout, mock_sts_credentials + ): + # Arrange + exchange_name = "test_exchange" + region = "us-west-2" + sns_fanout.channel.get_sts_credentials = Mock(return_value=mock_sts_credentials) + + boto_client_mock = Mock(name="My new boto client") + sns_fanout.channel._new_boto_client = Mock(return_value=boto_client_mock) + + # Act + result = sns_fanout._create_boto_client_with_sts_session(exchange_name, region) + + # Assert + assert result is boto_client_mock + + # Check class vars have been updated + assert sns_fanout.sts_expiration == mock_sts_credentials["Expiration"] + assert sns_fanout._predefined_clients[exchange_name] == boto_client_mock + + # Check calls + assert sns_fanout.channel.get_sts_credentials.call_args_list == [call()] + assert sns_fanout.channel._new_boto_client.call_args_list == [ + call( + service="sns", + region="us-west-2", + access_key_id="test_access_key", + secret_access_key="test_secret_key", + session_token="test_session_token", + ) + ] + + +class test_SnsSubscription: + @pytest.fixture + def mock_get_queue_arn(self, sns_subscription): + with patch.object(sns_subscription, "_get_queue_arn") as mock: + yield mock + + @pytest.fixture + def mock_get_topic_arn(self, sns_fanout): + with patch.object(sns_fanout, "_get_topic_arn") as mock: + yield mock + + @pytest.fixture + def mock_get_client(self, sns_fanout): + with patch.object(sns_fanout, "get_client") as mock: + yield mock + + @pytest.fixture + def mock_set_permission_on_sqs_queue(self, sns_subscription): + with patch.object(sns_subscription, "_set_permission_on_sqs_queue") as mock: + yield mock + + @pytest.fixture + def mock_subscribe_queue_to_sns_topic(self, sns_subscription): + with patch.object(sns_subscription, "_subscribe_queue_to_sns_topic") as mock: + yield mock + + def test_subscribe_queue_already_subscribed( + self, + sns_subscription, + 
sns_fanout, + mock_get_client + ): + # Arrange + queue_name = "test_queue" + exchange_name = "test_exchange" + cached_subscription_arn = "arn:aws:sns:us-east-1:123456789012:test_topic:cached" + sns_subscription._subscription_arn_cache[f"{exchange_name}:{queue_name}"] = ( + cached_subscription_arn + ) + + # Act + result = sns_subscription.subscribe_queue(queue_name, exchange_name) + + # Assert + assert result == cached_subscription_arn + assert mock_get_client.call_count == 0 + + def test_subscribe_queue_success_queue_in_cache( + self, + sns_subscription, + caplog, + mock_get_topic_arn, + mock_get_queue_arn, + mock_set_permission_on_sqs_queue, + mock_subscribe_queue_to_sns_topic, + ): + # Arrange + queue_name = "test_queue" + exchange_name = "test_exchange" + queue_arn = "arn:aws:sqs:us-east-1:123456789012:test_queue" + topic_arn = "arn:aws:sns:us-east-1:123456789012:test_topic" + subscription_arn = "arn:aws:sns:us-east-1:123456789012:test_topic:12345678-1234-1234-1234-123456789012" + + mock_get_queue_arn.return_value = queue_arn + mock_get_topic_arn.return_value = topic_arn + mock_subscribe_queue_to_sns_topic.return_value = subscription_arn + + # Act + result = sns_subscription.subscribe_queue(queue_name, exchange_name) + + # Assert + assert result == subscription_arn + assert ( + sns_subscription._subscription_arn_cache[f"{exchange_name}:{queue_name}"] + == subscription_arn + ) + assert mock_get_queue_arn.call_args_list == [call("test_queue")] + assert mock_get_topic_arn.call_args_list == [call(exchange_name)] + assert mock_subscribe_queue_to_sns_topic.call_args_list == [ + call(topic_arn=topic_arn, queue_arn=queue_arn) + ] + assert mock_set_permission_on_sqs_queue.call_args_list == [ + call(topic_arn=topic_arn, queue_arn=queue_arn, queue_name=queue_name) + ] + + def test_unsubscribe_queue_not_in_cache( + self, + sns_subscription, + ): + # Arrange + queue_name = "test_queue" + exchange_name = "test_exchange" + sns_subscription._subscription_arn_cache = { + 
"another-exchange:another_queue": "123" + } + sns_subscription._unsubscribe_sns_subscription = Mock() + + # Act + result = sns_subscription.unsubscribe_queue(queue_name, exchange_name) + + # Assert + assert result is None + assert sns_subscription._unsubscribe_sns_subscription.call_count == 0 + + def test_unsubscribe_queue_in_cache(self, sns_subscription, caplog): + # Arrange + caplog.set_level(logging.DEBUG) + queue_name = "test_queue" + exchange_name = "test_exchange" + subscription_arn = "arn:aws:sns:us-east-1:123456789012:test_topic:12345678-1234-1234-1234-123456789012" + sns_subscription._subscription_arn_cache = { + "test_exchange:test_queue": subscription_arn + } + sns_subscription._unsubscribe_sns_subscription = Mock() + + # Act + result = sns_subscription.unsubscribe_queue(queue_name, exchange_name) + + # Assert + assert result is None + assert ( + f"Unsubscribed subscription '{subscription_arn}' for SQS queue '{queue_name}'" + in caplog.text + ) + assert sns_subscription._unsubscribe_sns_subscription.call_args_list == [ + call(subscription_arn) + ] + + def test_cleanup_with_predefined_exchanges( + self, sns_subscription, caplog, channel_fixture, sns_fanout + ): + # Arrange + caplog.set_level(logging.DEBUG) + + exchange_name = "exchange-1" + + channel_fixture.predefined_exchanges = {"exchange-1": {}} + sns_fanout._get_topic_arn = Mock() + + # Act + result = sns_subscription.cleanup(exchange_name) + + # Assert + assert result is None + assert ( + "'predefined_exchanges' has been specified, so stale SNS subscription" + " cleanup will be skipped." 
+ ) in caplog.text + assert sns_fanout._get_topic_arn.call_count == 0 + + def test_cleanup_no_invalid_subscriptions( + self, sns_subscription, caplog, channel_fixture, sns_fanout + ): + # Arrange + caplog.set_level(logging.DEBUG) + + topic_arn = "arn:aws:sns:us-east-1:123456789012:my-topic" + exchange_name = "exchange-1" + + channel_fixture.predefined_exchanges = {} + sns_fanout._get_topic_arn = Mock(return_value=topic_arn) + sns_subscription._get_invalid_sns_subscriptions = Mock(return_value=[]) + sns_subscription._unsubscribe_sns_subscription = Mock() + + # Act + result = sns_subscription.cleanup(exchange_name) + + # Assert + assert result is None + assert ( + f"Checking for stale SNS subscriptions for exchange '{exchange_name}'" + ) in caplog.text + assert sns_fanout._get_topic_arn.call_args_list == [call(exchange_name)] + assert sns_subscription._unsubscribe_sns_subscription.call_count == 0 + + def test_cleanup_with_invalid_subscriptions( + self, sns_subscription, caplog, channel_fixture, sns_fanout + ): + # Arrange + caplog.set_level(logging.DEBUG) + + topic_arn = "arn:aws:sns:us-east-1:123456789012:my-topic" + exchange_name = "exchange-1" + + channel_fixture.predefined_exchanges = {} + sns_fanout._get_topic_arn = Mock(return_value=topic_arn) + sns_subscription._get_invalid_sns_subscriptions = Mock( + return_value=[ + "subscription-arn-1", + "subscription-arn-2", + "subscription-arn-3", + ] + ) + + # Ensure that we carry on after hitting an exception + sns_subscription._unsubscribe_sns_subscription = Mock( + side_effect=[None, ConnectionError("A test exception"), None] + ) + + # Act + result = sns_subscription.cleanup(exchange_name) + + # Assert + assert result is None + assert sns_fanout._get_topic_arn.call_args_list == [call(exchange_name)] + assert sns_subscription._unsubscribe_sns_subscription.call_args_list == [ + call("subscription-arn-1"), + call("subscription-arn-2"), + call("subscription-arn-3"), + ] + + # Check logs + log_lines = [ + f"Removed stale 
subscription 'subscription-arn-1' for SNS topic '{topic_arn}'", + f"Failed to remove stale subscription 'subscription-arn-2' for SNS topic" + f" '{topic_arn}': A test exception", + f"Removed stale subscription 'subscription-arn-3' for SNS topic '{topic_arn}'", + ] + for line in log_lines: + assert line in caplog.text + + def test_set_permission_on_sqs_queue( + self, sns_subscription, caplog, mock_sqs, channel_fixture + ): + # Arrange + caplog.set_level(logging.DEBUG) + + topic_arn = "arn:aws:sns:us-east-1:123456789012:my-topic" + queue_name = "my-queue" + queue_arn = "arn:aws:sqs:us-east-1:123456789012:my-queue" + + channel_fixture.predefined_queues = {} + channel_fixture.sqs.return_value = mock_sqs() + channel_fixture._queue_cache[queue_name] = ( + "https://sqs.us-east-1.amazonaws.com/123456789012/my-queue" + ) + + exchange = Exchange("test_SQS", type="direct") + queue = Queue(queue_name, exchange) + queue(channel_fixture).declare() + + # Act + sns_subscription._set_permission_on_sqs_queue(topic_arn, queue_name, queue_arn) + + # Assert + expected_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "KombuManaged", + "Effect": "Allow", + "Principal": {"Service": "sns.amazonaws.com"}, + "Action": "SQS:SendMessage", + "Resource": queue_arn, + "Condition": {"ArnLike": {"aws:SourceArn": topic_arn}}, + } + ], + } + + assert mock_sqs().set_queue_attributes.call_args_list == [ + call( + QueueUrl="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue", + Attributes={"Policy": json.dumps(expected_policy)}, + ) + ] + + assert ( + "Set permissions on SNS topic 'arn:aws:sns:us-east-1:123456789012:my-topic'" + ) in caplog.text + + def test_subscribe_queue_to_sns_topic_successful_subscription( + self, sns_subscription, caplog, sns_fanout, mock_get_client + ): + # Arrange + caplog.set_level(logging.DEBUG) + + queue_arn = "arn:aws:sqs:us-west-2:123456789012:my-queue" + topic_arn = "arn:aws:sns:us-west-2:123456789012:my-topic" + subscription_arn = 
"arn:aws:sns:us-west-2:123456789012:my-topic:12345678-1234-1234-1234-123456789012" + mock_client = Mock() + mock_client.return_value.subscribe.return_value = { + "SubscriptionArn": subscription_arn, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + sns_fanout.get_client = mock_client + + # Act + result = sns_subscription._subscribe_queue_to_sns_topic(queue_arn, topic_arn) + + # Assert + assert result == subscription_arn + assert mock_get_client.call_args_list == [] + assert mock_client.return_value.subscribe.call_args_list == [ + call( + TopicArn="arn:aws:sns:us-west-2:123456789012:my-topic", + Protocol="sqs", + Endpoint="arn:aws:sqs:us-west-2:123456789012:my-queue", + Attributes={"RawMessageDelivery": "true"}, + ReturnSubscriptionArn=True, + ) + ] + assert ( + f"Subscribing queue '{queue_arn}' to SNS topic '{topic_arn}'" in caplog.text + ) + assert ( + f"Create subscription '{subscription_arn}' for SQS queue '{queue_arn}' to SNS topic '{topic_arn}'" + in caplog.text + ) + + def test_subscribe_queue_to_sns_topic_subscription_failure( + self, sns_subscription, sns_fanout + ): + # Arrange + queue_arn = "arn:aws:sqs:us-west-2:123456789012:my-queue" + topic_arn = "arn:aws:sns:us-west-2:123456789012:my-topic" + mock_client = Mock() + mock_client.return_value.subscribe.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 400} + } + sns_fanout.get_client = mock_client + + # Act and Assert + with pytest.raises( + Exception, match="Unable to subscribe queue: status code was 400" + ): + sns_subscription._subscribe_queue_to_sns_topic(queue_arn, topic_arn) + + def test_subscribe_queue_to_sns_topic_client_error( + self, sns_subscription, sns_fanout + ): + # Arrange + queue_arn = "arn:aws:sqs:us-west-2:123456789012:my-queue" + topic_arn = "arn:aws:sns:us-west-2:123456789012:my-topic" + + mock_client = Mock() + mock_client.return_value.subscribe.side_effect = ClientError( + error_response={"Error": {"Code": "InvalidParameter"}}, + operation_name="Subscribe", + ) + 
sns_fanout.get_client = mock_client + + # Act and Assert + with pytest.raises(ClientError): + sns_subscription._subscribe_queue_to_sns_topic(queue_arn, topic_arn) + + def test_unsubscribe_sns_subscription_success(self, sns_subscription, sns_fanout): + # Arrange + subscription_arn = ( + "arn:aws:sns:us-west-2:123456789012:my-topic:12345678-12:sub-id" + ) + + mock_client = Mock() + mock_client.return_value.unsubscribe.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200} + } + sns_fanout.get_client = mock_client + + # Act + result = sns_subscription._unsubscribe_sns_subscription(subscription_arn) + + # Assert + assert result is None + assert mock_client.return_value.unsubscribe.call_args_list == [ + call(SubscriptionArn=subscription_arn) + ] + + def test_unsubscribe_sns_subscription_error( + self, sns_subscription, sns_fanout, caplog + ): + # Arrange + caplog.set_level(logging.DEBUG) + subscription_arn = ( + "arn:aws:sns:us-west-2:123456789012:my-topic:12345678-12:sub-id" + ) + + mock_client = Mock() + mock_client.return_value.unsubscribe.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 400} + } + sns_fanout.get_client = mock_client + + # Act + result = sns_subscription._unsubscribe_sns_subscription(subscription_arn) + + # Assert + assert result is None + assert mock_client.return_value.unsubscribe.call_args_list == [ + call(SubscriptionArn=subscription_arn) + ] + assert ( + f"Unable to remove subscription '{subscription_arn}': status code was 400" + ) in caplog.text + + def test_get_invalid_sns_subscriptions(self, sns_subscription, sns_fanout): + # Arrange + client_mock = Mock() + sns_fanout.get_client = client_mock + + # Mock paginator + mock_paginate = Mock() + sns_fanout.get_client().get_paginator.return_value = mock_paginate + mock_paginate.paginate.return_value = iter( + [ + { + "Subscriptions": [ + {"SubscriptionArn": "arn1"}, + {"SubscriptionArn": "arn2"}, + ] + }, + {"Subscriptions": [{"SubscriptionArn": "arn3"}]}, + ] + ) + + # Mock filter 
+ sns_subscription._filter_sns_subscription_response = Mock( + side_effect=[["arn3"], ["arn2"]] + ) + + sns_topic_arn = "arn:aws:sns:us-west-2:123456789012:my-topic" + + # Act + result = sns_subscription._get_invalid_sns_subscriptions(sns_topic_arn) + + # Assert + assert result == ["arn3", "arn2"] + assert mock_paginate.paginate.call_args_list == [call(TopicArn=sns_topic_arn)] + assert sns_subscription._filter_sns_subscription_response.call_args_list == [ + call([{"SubscriptionArn": "arn1"}, {"SubscriptionArn": "arn2"}]), + call([{"SubscriptionArn": "arn3"}]), + ] + + def test_get_invalid_sns_subscriptions_empty(self, sns_subscription, sns_fanout): + # Arrange + client_mock = Mock() + sns_fanout.get_client = client_mock + + # Mock paginator + mock_paginate = Mock() + sns_fanout.get_client().get_paginator.return_value = mock_paginate + mock_paginate.paginate.return_value = iter( + [ + {"Subscriptions": []}, + {"Subscriptions": []}, + ] + ) + + # Mock filter + sns_subscription._filter_sns_subscription_response = Mock(return_value=[]) + + # Act + result = sns_subscription._get_invalid_sns_subscriptions( + "arn:aws:sns:us-west-2:123456789012:my-topic" + ) + + # Assert + assert result == [] + + def test_get_invalid_sns_subscriptions_no_subscriptions_key( + self, sns_subscription, sns_fanout + ): + # Arrange + client_mock = Mock() + sns_fanout.get_client = client_mock + + # Mock paginator + mock_paginate = Mock() + sns_fanout.get_client().get_paginator.return_value = mock_paginate + mock_paginate.paginate.return_value = iter( + [ + {}, + {"Subscriptions": [{"SubscriptionArn": "arn1"}]}, + ] + ) + + # Mock filter + sns_subscription._filter_sns_subscription_response = Mock(return_value=[]) + + sns_topic_arn = "arn:aws:sns:us-west-2:123456789012:my-topic" + + sns_subscription._filter_sns_subscription_response = Mock( + side_effect=[[], ["arn1"]] + ) + + # Act + result = sns_subscription._get_invalid_sns_subscriptions(sns_topic_arn) + + # Assert + assert result == ["arn1"] + 
assert sns_subscription._filter_sns_subscription_response.call_args_list == [ + call(None), + call([{"SubscriptionArn": "arn1"}]), + ] + + @pytest.mark.parametrize("value", [None, "", []]) + def test__filter_sns_subscription_response_nothing_provided( + self, value, sns_subscription + ): + # Act & Assert + assert sns_subscription._filter_sns_subscription_response(value) == [] + + def test__filter_sns_subscription_response(self, sns_subscription, channel_fixture): + # Arrange + subscriptions = [ + { + "Protocol": "SqS", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-1", + "SubscriptionArn": "arn-1", + }, # Test case-sensitivity + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-2", + "SubscriptionArn": "arn-2", + }, # Test case-sensitivity + { + "Protocol": "Lambda", + "Endpoint": "arn:aws:lambda:us-west-2:123456789012:function:my-lambda-function", + "SubscriptionArn": "lambda-arn-1", + }, # This should be filtered out + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-3", + "SubscriptionArn": "arn-3", + }, + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-4", + "SubscriptionArn": "arn-4", + }, + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-5", + "SubscriptionArn": "arn-5", + }, + { + "Protocol": "SQS", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-6", + "SubscriptionArn": "arn-6", + }, + ] + sqs_mock = Mock() + channel_fixture.sqs = sqs_mock + + # Setup errors on queues 2,4 and 5 + sqs_mock.return_value.get_queue_url.side_effect = [ + None, # queue-1 + ClientError( + error_response={"Error": {"Code": "QueueDoesNotExist"}}, + operation_name="GetQueueUrl", + ), # queue-2 + None, # queue-3 + ClientError( + error_response={"Error": {"Code": "NonExistentQueue"}}, + operation_name="GetQueueUrl", + ), # queue-4 + ClientError( + 
error_response={"Error": {"Code": "NonExistentQueue"}}, + operation_name="GetQueueUrl", + ), # queue-5 + None, # queue-6 + ] + + # Act + result = sns_subscription._filter_sns_subscription_response(subscriptions) + + # Assert + assert result == ["arn-2", "arn-4", "arn-5"] + assert sqs_mock.return_value.get_queue_url.call_args_list == [ + call(QueueName="//sqs.us-west-2.amazonaws.com/123456789012/my-queue-1"), + call(QueueName="//sqs.us-west-2.amazonaws.com/123456789012/my-queue-2"), + call(QueueName="//sqs.us-west-2.amazonaws.com/123456789012/my-queue-3"), + call(QueueName="//sqs.us-west-2.amazonaws.com/123456789012/my-queue-4"), + call(QueueName="//sqs.us-west-2.amazonaws.com/123456789012/my-queue-5"), + call(QueueName="//sqs.us-west-2.amazonaws.com/123456789012/my-queue-6"), + ] + + @pytest.mark.parametrize( + "exc_type", [ClientError, ValueError, Exception, IndexError, KeyError] + ) + def test__filter_sns_subscription_response_exceptions( + self, exc_type, sns_subscription, channel_fixture + ): + # Arrange + subscriptions = [ + { + "Protocol": "SqS", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-1", + "SubscriptionArn": "arn-1", + }, # Test case-sensitivity + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-2", + "SubscriptionArn": "arn-2", + }, # Test case-sensitivity + { + "Protocol": "Lambda", + "Endpoint": "arn:aws:lambda:us-west-2:123456789012:function:my-lambda-function", + "SubscriptionArn": "lambda-arn-1", + }, # This should be filtered out + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-3", + "SubscriptionArn": "arn-3", + }, + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-4", + "SubscriptionArn": "arn-4", + }, + { + "Protocol": "sqs", + "Endpoint": "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-5", + "SubscriptionArn": "arn-5", + }, + { + "Protocol": "SQS", + "Endpoint": 
"https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-6", + "SubscriptionArn": "arn-6", + }, + ] + sqs_mock = Mock() + channel_fixture.sqs = sqs_mock + + # Build exception + if exc_type == ClientError: + exc = ClientError( + error_response={"Error": {"Code": "ThisIsATest"}}, + operation_name="GetQueueUrl", + ) + else: + exc = exc_type("This is a test exception") + + sqs_mock.return_value.get_queue_url.side_effect = [ + None, # queue-1 + exc, # queue-2 + None, # queue-3 + ClientError( + error_response={"Error": {"Code": "NonExistentQueue"}}, + operation_name="GetQueueUrl", + ), # queue-4 + ClientError( + error_response={"Error": {"Code": "NonExistentQueue"}}, + operation_name="GetQueueUrl", + ), # queue-5 + None, # queue-6 + ] + + # Act & Assert + with pytest.raises(exc_type): + sns_subscription._filter_sns_subscription_response(subscriptions) + + def test_get_queue_arn_in_cache(self, sns_subscription, sns_fanout): + # Arrange + sns_subscription._queue_arn_cache = { + "my_queue": "arn:aws:sqs:us-west-2:123456789012:my-queue", + "my-queue-2": "arn:aws:sqs:us-west-2:123456789012:my-queue-2", + "my-queue-3": "arn:aws:sqs:us-west-2:123456789012:my-queue-3", + } + + chan_mock = Mock() + sns_fanout.channel = chan_mock + + # Act + result = sns_subscription._get_queue_arn("my-queue-2") + + # Assert + assert result == "arn:aws:sqs:us-west-2:123456789012:my-queue-2" + assert chan_mock._resolve_queue_url.call_count == 0 + + def test_get_queue_arn_lookup_success(self, sns_subscription, sns_fanout): + # Arrange + sns_subscription._queue_arn_cache = { + "my_queue": "arn:aws:sqs:us-west-2:123456789012:my-queue", + "my-queue-2": "arn:aws:sqs:us-west-2:123456789012:my-queue-2", + "my-queue-3": "arn:aws:sqs:us-west-2:123456789012:my-queue-3", + } + queue_arn = "arn:aws:sqs:us-west-2:123456789012:my-queue-4" + queue_url = "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-4" + + chan_mock = Mock() + sns_fanout.channel = chan_mock + 
chan_mock._resolve_queue_url.return_value = queue_url + chan_mock.sqs.return_value.get_queue_attributes.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "Attributes": {"QueueArn": queue_arn}, + } + assert "my-queue-4" not in sns_subscription._queue_arn_cache + + # Act + result = sns_subscription._get_queue_arn("my-queue-4") + + # Assert + assert result == queue_arn + assert chan_mock._resolve_queue_url.call_args_list == [call("my-queue-4")] + assert chan_mock.sqs.return_value.get_queue_attributes.call_args_list == [ + call(QueueUrl=queue_url, AttributeNames=["QueueArn"]) + ] + assert sns_subscription._queue_arn_cache["my-queue-4"] == queue_arn + + def test_get_queue_arn_lookup_failure(self, sns_subscription, sns_fanout): + # Arrange + sns_subscription._queue_arn_cache = { + "my_queue": "arn:aws:sqs:us-west-2:123456789012:my-queue", + "my-queue-2": "arn:aws:sqs:us-west-2:123456789012:my-queue-2", + "my-queue-3": "arn:aws:sqs:us-west-2:123456789012:my-queue-3", + } + queue_arn = "arn:aws:sqs:us-west-2:123456789012:my-queue-4" + queue_url = "https://sqs.us-west-2.amazonaws.com/123456789012/my-queue-4" + + chan_mock = Mock() + sns_fanout.channel = chan_mock + chan_mock._resolve_queue_url.return_value = queue_url + chan_mock.sqs.return_value.get_queue_attributes.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 500}, + "Attributes": {"QueueArn": queue_arn}, + } + assert "my-queue-4" not in sns_subscription._queue_arn_cache + + # Act & assert + with pytest.raises( + KombuError, + match="Unable to get ARN for SQS queue 'my-queue-4': status code was '500'", + ): + sns_subscription._get_queue_arn("my-queue-4")