diff --git a/CHANGELOG.md b/CHANGELOG.md index fdfdc6d896..404eb8abf5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2747](https://github.com/Pycord-Development/pycord/pull/2747)) - Added `discord.Interaction.created_at`. ([#2801](https://github.com/Pycord-Development/pycord/pull/2801)) +- Added support for asynchronous functions in dynamic cooldowns. + ([#2823](https://github.com/Pycord-Development/pycord/pull/2823)) ### Fixed diff --git a/discord/commands/core.py b/discord/commands/core.py index 06996bcaa1..4a407fab43 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -333,10 +333,10 @@ def guild_only(self, value: bool) -> None: InteractionContextType.private_channel, } - def _prepare_cooldowns(self, ctx: ApplicationContext): + async def _prepare_cooldowns(self, ctx: ApplicationContext): if self._buckets.valid: current = datetime.datetime.now().timestamp() - bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message + bucket = await self._buckets.get_bucket(ctx, current) if bucket is not None: retry_after = bucket.update_rate_limit(current) @@ -356,11 +356,9 @@ async def prepare(self, ctx: ApplicationContext) -> None: ) if self._max_concurrency is not None: - # For this application, context can be duck-typed as a Message - await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message - + await self._max_concurrency.acquire(ctx) try: - self._prepare_cooldowns(ctx) + await self._prepare_cooldowns(ctx) await self.call_before_hooks(ctx) except: if self._max_concurrency is not None: @@ -400,7 +398,7 @@ def reset_cooldown(self, ctx: ApplicationContext) -> None: The invocation context to reset the cooldown under. """ if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message + bucket = self._buckets.get_bucket(ctx) bucket.reset() def get_cooldown_retry_after(self, ctx: ApplicationContext) -> float: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index e2be74a21b..75773494ea 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -311,7 +311,11 @@ def me(self) -> Member | ClientUser: message contexts, or when :meth:`Intents.guilds` is absent. """ # bot.user will never be None at this point. - return self.guild.me if self.guild is not None and self.guild.me is not None else self.bot.user # type: ignore + return ( + self.guild.me + if self.guild is not None and self.guild.me is not None + else self.bot.user + ) # type: ignore @property def voice_client(self) -> VoiceProtocol | None: diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 6e58d37f7a..86b9fa588e 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -26,9 +26,10 @@ from __future__ import annotations import asyncio +import inspect import time from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Deque, TypeVar import discord.abc from discord.enums import Enum @@ -37,6 +38,8 @@ from .errors import MaxConcurrencyReached if TYPE_CHECKING: + from ...commands import ApplicationContext + from ...ext.commands import Context from ...message import Message __all__ = ( @@ -60,31 +63,35 @@ class BucketType(Enum): category = 5 role = 6 - def get_key(self, msg: Message) -> Any: + def get_key(self, ctx: Context | ApplicationContext) -> Any: if self is BucketType.user: - return msg.author.id + return ctx.author.id elif self is BucketType.guild: - return (msg.guild or msg.author).id + return (ctx.guild or ctx.author).id elif self is BucketType.channel: - return msg.channel.id + return ctx.channel.id elif self is BucketType.member: - return (msg.guild and msg.guild.id), msg.author.id + return (ctx.guild and ctx.guild.id), ctx.author.id elif self is BucketType.category: return ( - msg.channel.category.id - if isinstance(msg.channel, discord.abc.GuildChannel) - and msg.channel.category - else msg.channel.id + ctx.channel.category.id + if isinstance(ctx.channel, discord.abc.GuildChannel) + and ctx.channel.category + else ctx.channel.id ) elif self is BucketType.role: # we return the channel id of a private-channel as there are only roles in guilds # and that yields the same result as for a guild with only the @everyone role # NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore + return ( + ctx.channel + if isinstance(ctx.channel, PrivateChannel) + else ctx.author.top_role + ).id # type: ignore - def __call__(self, msg: Message) -> Any: - return self.get_key(msg) + def __call__(self, ctx: Context | ApplicationContext) -> Any: + return self.get_key(ctx) class Cooldown: @@ -208,14 +215,14 @@ class CooldownMapping: def __init__( self, original: Cooldown | None, - type: Callable[[Message], Any], + type: Callable[[Context | ApplicationContext], Any], ) -> None: if not callable(type): raise TypeError("Cooldown type must be a BucketType or callable") self._cache: dict[Any, Cooldown] = {} self._cooldown: Cooldown | None = original - self._type: Callable[[Message], Any] = type + self._type: Callable[[Context | ApplicationContext], Any] = type def copy(self) -> CooldownMapping: ret = CooldownMapping(self._cooldown, self._type) @@ -227,15 +234,15 @@ def valid(self) -> bool: return self._cooldown is not None @property - def type(self) -> Callable[[Message], Any]: + def type(self) -> Callable[[Context | ApplicationContext], Any]: return self._type @classmethod def from_cooldown(cls: type[C], rate, per, type) -> C: return cls(Cooldown(rate, per), type) - def _bucket_key(self, msg: Message) -> Any: - return self._type(msg) + def _bucket_key(self, ctx: Context | ApplicationContext) -> Any: + return self._type(ctx) def _verify_cache_integrity(self, current: float | None = None) -> None: # we want to delete all cache objects that haven't been used @@ -246,17 +253,19 @@ def _verify_cache_integrity(self, current: float | None = None) -> None: for k in dead_keys: del self._cache[k] - def create_bucket(self, message: Message) -> Cooldown: + async def create_bucket(self, ctx: Context | ApplicationContext) -> Cooldown: return self._cooldown.copy() # type: ignore - def get_bucket(self, message: Message, current: float | None = None) -> Cooldown: + async def get_bucket( + self, ctx: Context | ApplicationContext, current: float | None = None + ) -> Cooldown: if self._type is BucketType.default: return self._cooldown # type: ignore self._verify_cache_integrity(current) - key = self._bucket_key(message) + key = self._bucket_key(ctx) if key not in self._cache: - bucket = self.create_bucket(message) + bucket = await self.create_bucket(ctx) if bucket is not None: self._cache[key] = bucket else: @@ -264,19 +273,25 @@ def get_bucket(self, message: Message, current: float | None = None) -> Cooldown return bucket - def update_rate_limit( - self, message: Message, current: float | None = None + async def update_rate_limit( + self, ctx: Context | ApplicationContext, current: float | None = None ) -> float | None: - bucket = self.get_bucket(message, current) + bucket = await self.get_bucket(ctx, current) return bucket.update_rate_limit(current) class DynamicCooldownMapping(CooldownMapping): def __init__( - self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] + self, + factory: Callable[ + [Context | ApplicationContext], Cooldown | Awaitable[Cooldown] + ], + type: Callable[[Context | ApplicationContext], Any], ) -> None: super().__init__(None, type) - self._factory: Callable[[Message], Cooldown] = factory + self._factory: Callable[ + [Context | ApplicationContext], Cooldown | Awaitable[Cooldown] + ] = factory def copy(self) -> DynamicCooldownMapping: ret = DynamicCooldownMapping(self._factory, self._type) @@ -287,8 +302,16 @@ def copy(self) -> DynamicCooldownMapping: def valid(self) -> bool: return True - def create_bucket(self, message: Message) -> Cooldown: - return self._factory(message) + async def create_bucket(self, ctx: Context | ApplicationContext) -> Cooldown: + from ...ext.commands import Context + + if isinstance(ctx, Context): + result = self._factory(ctx.message) + else: + result = self._factory(ctx) + if inspect.isawaitable(result): + return await result + return result class _Semaphore: @@ -376,11 +399,11 @@ def __repr__(self) -> str: f"" ) - def get_key(self, message: Message) -> Any: - return self.per.get_key(message) + def get_key(self, ctx: Context | ApplicationContext) -> Any: + return self.per.get_key(ctx) - async def acquire(self, message: Message) -> None: - key = self.get_key(message) + async def acquire(self, ctx: Context | ApplicationContext) -> None: + key = self.get_key(ctx) try: sem = self._mapping[key] @@ -391,10 +414,10 @@ async def acquire(self, message: Message) -> None: if not acquired: raise MaxConcurrencyReached(self.number, self.per) - async def release(self, message: Message) -> None: + async def release(self, ctx: Context | ApplicationContext) -> None: # Technically there's no reason for this function to be async # But it might be more useful in the future - key = self.get_key(message) + key = self.get_key(ctx) try: sem = self._mapping[key] diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 6f6ef1dffa..2f5a325d70 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -33,6 +33,7 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Generator, Generic, @@ -69,6 +70,7 @@ if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard + from discord import ApplicationContext from discord.message import Message from ._types import Check, Coro, CoroFunc, Error, Hook @@ -397,7 +399,9 @@ def __init__( # bandaid for the fact that sometimes parent can be the bot instance parent = kwargs.get("parent") - self.parent: GroupMixin | None = parent if isinstance(parent, _BaseCommand) else None # type: ignore + self.parent: GroupMixin | None = ( + parent if isinstance(parent, _BaseCommand) else None + ) # type: ignore self._before_invoke: Hook | None = None try: @@ -850,11 +854,11 @@ async def call_after_hooks(self, ctx: Context) -> None: if hook is not None: await hook(ctx) - def _prepare_cooldowns(self, ctx: Context) -> None: + async def _prepare_cooldowns(self, ctx: Context) -> None: if self._buckets.valid: dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() - bucket = self._buckets.get_bucket(ctx.message, current) + bucket = await self._buckets.get_bucket(ctx, current) if bucket is not None: retry_after = bucket.update_rate_limit(current) if retry_after: @@ -869,15 +873,14 @@ async def prepare(self, ctx: Context) -> None: ) if self._max_concurrency is not None: - # For this application, context can be duck-typed as a Message - await self._max_concurrency.acquire(ctx) # type: ignore + await self._max_concurrency.acquire(ctx) try: if self.cooldown_after_parsing: await self._parse_arguments(ctx) - self._prepare_cooldowns(ctx) + await self._prepare_cooldowns(ctx) else: - self._prepare_cooldowns(ctx) + await self._prepare_cooldowns(ctx) await self._parse_arguments(ctx) await self.call_before_hooks(ctx) @@ -1204,7 +1207,9 @@ async def can_run(self, ctx: Context) -> bool: # since we have no checks, then we just return True. return True - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore + return await discord.utils.async_all( + predicate(ctx) for predicate in predicates + ) # type: ignore finally: ctx.command = original @@ -2353,36 +2358,23 @@ def decorator(func: Command | CoroFunc) -> Command | CoroFunc: def dynamic_cooldown( - cooldown: BucketType | Callable[[Message], Any], + cooldown: Callable[ + [Context | ApplicationContext], Cooldown | Awaitable[Cooldown] | None + ], type: BucketType = BucketType.default, ) -> Callable[[T], T]: - """A decorator that adds a dynamic cooldown to a command - - This differs from :func:`.cooldown` in that it takes a function that - accepts a single parameter of type :class:`.discord.Message` and must - return a :class:`.Cooldown` or ``None``. If ``None`` is returned then - that cooldown is effectively bypassed. - - A cooldown allows a command to only be used a specific amount - of times in a specific time frame. These cooldowns can be based - either on a per-guild, per-channel, per-user, per-role or global basis. - Denoted by the third argument of ``type`` which must be of enum - type :class:`.BucketType`. - - If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in - :func:`.on_command_error` and the local error handler. - - A command can only have a single cooldown. + """A decorator that adds a dynamic cooldown to a command. - .. versionadded:: 2.0 + This supports both sync and async cooldown factories and accepts either + a :class:`discord.Message` or :class:`discord.ApplicationContext`. Parameters ---------- - cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`.Cooldown`]] - A function that takes a message and returns a cooldown that will - apply to this invocation or ``None`` if the cooldown should be bypassed. - type: :class:`.BucketType` - The type of cooldown to have. + cooldown: Callable[[Union[Message, ApplicationContext]], Union[Cooldown, Awaitable[Cooldown], None]] + A function that takes a message or context and returns a cooldown + to apply for that invocation or ``None`` to bypass. + type: :class:`BucketType` + The cooldown bucket type (e.g. per-user, per-channel). """ if not callable(cooldown): raise TypeError("A callable must be provided")