Skip to content

feat: update cooldown handling to support async operations #2823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ These changes are available on the `master` branch, but have not yet been releas
([#2747](https://github.com/Pycord-Development/pycord/pull/2747))
- Added `discord.Interaction.created_at`.
([#2801](https://github.com/Pycord-Development/pycord/pull/2801))
- Added support for asynchronous functions in dynamic cooldowns.
([#2823](https://github.com/Pycord-Development/pycord/pull/2823))

### Fixed

Expand Down
12 changes: 5 additions & 7 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,10 @@ def guild_only(self, value: bool) -> None:
InteractionContextType.private_channel,
}

def _prepare_cooldowns(self, ctx: ApplicationContext):
async def _prepare_cooldowns(self, ctx: ApplicationContext):
if self._buckets.valid:
current = datetime.datetime.now().timestamp()
bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message
bucket = await self._buckets.get_bucket(ctx, current)

if bucket is not None:
retry_after = bucket.update_rate_limit(current)
Expand All @@ -356,11 +356,9 @@ async def prepare(self, ctx: ApplicationContext) -> None:
)

if self._max_concurrency is not None:
# For this application, context can be duck-typed as a Message
await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message

await self._max_concurrency.acquire(ctx)
try:
self._prepare_cooldowns(ctx)
await self._prepare_cooldowns(ctx)
await self.call_before_hooks(ctx)
except:
if self._max_concurrency is not None:
Expand Down Expand Up @@ -400,7 +398,7 @@ def reset_cooldown(self, ctx: ApplicationContext) -> None:
The invocation context to reset the cooldown under.
"""
if self._buckets.valid:
bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message
bucket = self._buckets.get_bucket(ctx)
bucket.reset()

def get_cooldown_retry_after(self, ctx: ApplicationContext) -> float:
Expand Down
6 changes: 5 additions & 1 deletion discord/ext/commands/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,11 @@ def me(self) -> Member | ClientUser:
message contexts, or when :meth:`Intents.guilds` is absent.
"""
# bot.user will never be None at this point.
return self.guild.me if self.guild is not None and self.guild.me is not None else self.bot.user # type: ignore
return (
self.guild.me
if self.guild is not None and self.guild.me is not None
else self.bot.user
) # type: ignore

@property
def voice_client(self) -> VoiceProtocol | None:
Expand Down
93 changes: 58 additions & 35 deletions discord/ext/commands/cooldowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
from __future__ import annotations

import asyncio
import inspect
import time
from collections import deque
from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Deque, TypeVar

import discord.abc
from discord.enums import Enum
Expand All @@ -37,6 +38,8 @@
from .errors import MaxConcurrencyReached

if TYPE_CHECKING:
from ...commands import ApplicationContext
from ...ext.commands import Context
from ...message import Message

__all__ = (
Expand All @@ -60,31 +63,35 @@ class BucketType(Enum):
category = 5
role = 6

def get_key(self, msg: Message) -> Any:
def get_key(self, ctx: Context | ApplicationContext) -> Any:
if self is BucketType.user:
return msg.author.id
return ctx.author.id
elif self is BucketType.guild:
return (msg.guild or msg.author).id
return (ctx.guild or ctx.author).id
elif self is BucketType.channel:
return msg.channel.id
return ctx.channel.id
elif self is BucketType.member:
return (msg.guild and msg.guild.id), msg.author.id
return (ctx.guild and ctx.guild.id), ctx.author.id
elif self is BucketType.category:
return (
msg.channel.category.id
if isinstance(msg.channel, discord.abc.GuildChannel)
and msg.channel.category
else msg.channel.id
ctx.channel.category.id
if isinstance(ctx.channel, discord.abc.GuildChannel)
and ctx.channel.category
else ctx.channel.id
)
elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
return (
ctx.channel
if isinstance(ctx.channel, PrivateChannel)
else ctx.author.top_role
).id # type: ignore

def __call__(self, msg: Message) -> Any:
return self.get_key(msg)
def __call__(self, ctx: Context | ApplicationContext) -> Any:
return self.get_key(ctx)


class Cooldown:
Expand Down Expand Up @@ -208,14 +215,14 @@ class CooldownMapping:
def __init__(
self,
original: Cooldown | None,
type: Callable[[Message], Any],
type: Callable[[Context | ApplicationContext], Any],
) -> None:
if not callable(type):
raise TypeError("Cooldown type must be a BucketType or callable")

self._cache: dict[Any, Cooldown] = {}
self._cooldown: Cooldown | None = original
self._type: Callable[[Message], Any] = type
self._type: Callable[[Context | ApplicationContext], Any] = type

def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type)
Expand All @@ -227,15 +234,15 @@ def valid(self) -> bool:
return self._cooldown is not None

@property
def type(self) -> Callable[[Message], Any]:
def type(self) -> Callable[[Context | ApplicationContext], Any]:
return self._type

@classmethod
def from_cooldown(cls: type[C], rate, per, type) -> C:
return cls(Cooldown(rate, per), type)

def _bucket_key(self, msg: Message) -> Any:
return self._type(msg)
def _bucket_key(self, ctx: Context | ApplicationContext) -> Any:
return self._type(ctx)

def _verify_cache_integrity(self, current: float | None = None) -> None:
# we want to delete all cache objects that haven't been used
Expand All @@ -246,37 +253,45 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
for k in dead_keys:
del self._cache[k]

def create_bucket(self, message: Message) -> Cooldown:
async def create_bucket(self, ctx: Context | ApplicationContext) -> Cooldown:
return self._cooldown.copy() # type: ignore

def get_bucket(self, message: Message, current: float | None = None) -> Cooldown:
async def get_bucket(
self, ctx: Context | ApplicationContext, current: float | None = None
) -> Cooldown:
if self._type is BucketType.default:
return self._cooldown # type: ignore

self._verify_cache_integrity(current)
key = self._bucket_key(message)
key = self._bucket_key(ctx)
if key not in self._cache:
bucket = self.create_bucket(message)
bucket = await self.create_bucket(ctx)
if bucket is not None:
self._cache[key] = bucket
else:
bucket = self._cache[key]

return bucket

def update_rate_limit(
self, message: Message, current: float | None = None
async def update_rate_limit(
self, ctx: Context | ApplicationContext, current: float | None = None
) -> float | None:
bucket = self.get_bucket(message, current)
bucket = await self.get_bucket(ctx, current)
return bucket.update_rate_limit(current)


class DynamicCooldownMapping(CooldownMapping):
def __init__(
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
self,
factory: Callable[
[Context | ApplicationContext], Cooldown | Awaitable[Cooldown]
],
type: Callable[[Context | ApplicationContext], Any],
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
self._factory: Callable[
[Context | ApplicationContext], Cooldown | Awaitable[Cooldown]
] = factory

def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type)
Expand All @@ -287,8 +302,16 @@ def copy(self) -> DynamicCooldownMapping:
def valid(self) -> bool:
return True

def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
async def create_bucket(self, ctx: Context | ApplicationContext) -> Cooldown:
from ...ext.commands import Context

if isinstance(ctx, Context):
result = self._factory(ctx.message)
else:
result = self._factory(ctx)
if inspect.isawaitable(result):
return await result
return result


class _Semaphore:
Expand Down Expand Up @@ -376,11 +399,11 @@ def __repr__(self) -> str:
f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
)

def get_key(self, message: Message) -> Any:
return self.per.get_key(message)
def get_key(self, ctx: Context | ApplicationContext) -> Any:
return self.per.get_key(ctx)

async def acquire(self, message: Message) -> None:
key = self.get_key(message)
async def acquire(self, ctx: Context | ApplicationContext) -> None:
key = self.get_key(ctx)

try:
sem = self._mapping[key]
Expand All @@ -391,10 +414,10 @@ async def acquire(self, message: Message) -> None:
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)

async def release(self, message: Message) -> None:
async def release(self, ctx: Context | ApplicationContext) -> None:
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)
key = self.get_key(ctx)

try:
sem = self._mapping[key]
Expand Down
56 changes: 24 additions & 32 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Generator,
Generic,
Expand Down Expand Up @@ -69,6 +70,7 @@
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, TypeGuard

from discord import ApplicationContext
from discord.message import Message

from ._types import Check, Coro, CoroFunc, Error, Hook
Expand Down Expand Up @@ -397,7 +399,9 @@ def __init__(

# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get("parent")
self.parent: GroupMixin | None = parent if isinstance(parent, _BaseCommand) else None # type: ignore
self.parent: GroupMixin | None = (
parent if isinstance(parent, _BaseCommand) else None
) # type: ignore

self._before_invoke: Hook | None = None
try:
Expand Down Expand Up @@ -850,11 +854,11 @@ async def call_after_hooks(self, ctx: Context) -> None:
if hook is not None:
await hook(ctx)

def _prepare_cooldowns(self, ctx: Context) -> None:
async def _prepare_cooldowns(self, ctx: Context) -> None:
if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
bucket = self._buckets.get_bucket(ctx.message, current)
bucket = await self._buckets.get_bucket(ctx, current)
if bucket is not None:
retry_after = bucket.update_rate_limit(current)
if retry_after:
Expand All @@ -869,15 +873,14 @@ async def prepare(self, ctx: Context) -> None:
)

if self._max_concurrency is not None:
# For this application, context can be duck-typed as a Message
await self._max_concurrency.acquire(ctx) # type: ignore
await self._max_concurrency.acquire(ctx)

try:
if self.cooldown_after_parsing:
await self._parse_arguments(ctx)
self._prepare_cooldowns(ctx)
await self._prepare_cooldowns(ctx)
else:
self._prepare_cooldowns(ctx)
await self._prepare_cooldowns(ctx)
await self._parse_arguments(ctx)

await self.call_before_hooks(ctx)
Expand Down Expand Up @@ -1204,7 +1207,9 @@ async def can_run(self, ctx: Context) -> bool:
# since we have no checks, then we just return True.
return True

return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore
return await discord.utils.async_all(
predicate(ctx) for predicate in predicates
) # type: ignore
finally:
ctx.command = original

Expand Down Expand Up @@ -2353,36 +2358,23 @@ def decorator(func: Command | CoroFunc) -> Command | CoroFunc:


def dynamic_cooldown(
cooldown: BucketType | Callable[[Message], Any],
cooldown: Callable[
[Context | ApplicationContext], Cooldown | Awaitable[Cooldown] | None
],
type: BucketType = BucketType.default,
) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a command

This differs from :func:`.cooldown` in that it takes a function that
accepts a single parameter of type :class:`.discord.Message` and must
return a :class:`.Cooldown` or ``None``. If ``None`` is returned then
that cooldown is effectively bypassed.

A cooldown allows a command to only be used a specific amount
of times in a specific time frame. These cooldowns can be based
either on a per-guild, per-channel, per-user, per-role or global basis.
Denoted by the third argument of ``type`` which must be of enum
type :class:`.BucketType`.

If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in
:func:`.on_command_error` and the local error handler.

A command can only have a single cooldown.
"""A decorator that adds a dynamic cooldown to a command.

.. versionadded:: 2.0
This supports both sync and async cooldown factories and accepts either
a :class:`discord.Message` or :class:`discord.ApplicationContext`.

Parameters
----------
cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`.Cooldown`]]
A function that takes a message and returns a cooldown that will
apply to this invocation or ``None`` if the cooldown should be bypassed.
type: :class:`.BucketType`
The type of cooldown to have.
cooldown: Callable[[Union[Message, ApplicationContext]], Union[Cooldown, Awaitable[Cooldown], None]]
A function that takes a message or context and returns a cooldown
to apply for that invocation or ``None`` to bypass.
type: :class:`BucketType`
The cooldown bucket type (e.g. per-user, per-channel).
"""
if not callable(cooldown):
raise TypeError("A callable must be provided")
Expand Down