diff --git a/CHANGELOG.md b/CHANGELOG.md index fdfdc6d896..e1adee42a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2714](https://github.com/Pycord-Development/pycord/pull/2714)) - Added the ability to pass a `datetime.time` object to `format_dt`. ([#2747](https://github.com/Pycord-Development/pycord/pull/2747)) +- Added `Guild.get_or_fetch()` and `Client.get_or_fetch()` shortcut methods. + ([#2776](https://github.com/Pycord-Development/pycord/pull/2776)) - Added `discord.Interaction.created_at`. ([#2801](https://github.com/Pycord-Development/pycord/pull/2801)) @@ -147,6 +149,9 @@ These changes are available on the `master` branch, but have not yet been releas ([#2501](https://github.com/Pycord-Development/pycord/pull/2501)) - Deprecated `Interaction.cached_channel` in favor of `Interaction.channel`. ([#2658](https://github.com/Pycord-Development/pycord/pull/2658)) +- Deprecated `utils.get_or_fetch(attr, id)` in favor of + `utils.get_or_fetch(object_type, object_id)`. + ([#2776](https://github.com/Pycord-Development/pycord/pull/2776)) ### Removed diff --git a/discord/client.py b/discord/client.py index 6768d4a660..d8e3776340 100644 --- a/discord/client.py +++ b/discord/client.py @@ -31,7 +31,15 @@ import sys import traceback from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generator, Sequence, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Generator, + Sequence, + TypeVar, +) import aiohttp @@ -60,17 +68,25 @@ from .threads import Thread from .ui.view import View from .user import ClientUser, User -from .utils import MISSING +from .utils import _FETCHABLE, MISSING from .voice_client import VoiceClient from .webhook import Webhook from .widget import Widget if TYPE_CHECKING: from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime - from .channel import DMChannel + from .channel import ( + CategoryChannel, + DMChannel, + ForumChannel, + StageChannel, + TextChannel, + VoiceChannel, + ) from .member import Member from .message import Message from .poll import Poll + from .threads import Thread, ThreadMember from .voice_client import VoiceProtocol __all__ = ("Client",) @@ -1113,7 +1129,12 @@ def get_all_members(self) -> Generator[Member]: for guild in self.guilds: yield from guild.members - async def get_or_fetch_user(self, id: int, /) -> User | None: + @utils.deprecated( + instead="Client.get_or_fetch(User, id)", + since="2.7", + removed="3.0", + ) + async def get_or_fetch_user(self, id: int, /) -> User | None: # TODO: Remove in 3.0 """|coro| Looks up a user in the user cache or fetches if not found. @@ -1129,7 +1150,43 @@ async def get_or_fetch_user(self, id: int, /) -> User | None: The user or ``None`` if not found. """ - return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None) + return await self.get_or_fetch(object_type=User, object_id=id, default=None) + + async def get_or_fetch( + self: Client, + object_type: type[_FETCHABLE], + object_id: int | None, + default: Any = MISSING, + ) -> _FETCHABLE | None: + """ + Shortcut method to get data from an object either by returning the cached version, or if it does not exist, attempting to fetch it from the API. + + Parameters + ---------- + object_type: Union[:class:`VoiceChannel`, :class:`TextChannel`, :class:`ForumChannel`, :class:`StageChannel`, :class:`CategoryChannel`, :class:`Thread`, :class:`User`, :class:`Guild`, :class:`GuildEmoji`, :class:`AppEmoji`] + Type of object to fetch or get. + object_id: :class:`int` + ID of object to get. + default : Any, optional + A default to return instead of raising if fetch fails. + + Returns + ------- + Optional[Union[:class:`VoiceChannel`, :class:`TextChannel`, :class:`ForumChannel`, :class:`StageChannel`, :class:`CategoryChannel`, :class:`Thread`, :class:`User`, :class:`Guild`, :class:`GuildEmoji`, :class:`AppEmoji`]] + The object of type that was specified or ``None`` if not found. + + Raises + ------ + :exc:`NotFound` + Invalid ID for the object + :exc:`HTTPException` + An error occurred fetching the object + :exc:`Forbidden` + You do not have permission to fetch the object + """ + return await utils.get_or_fetch( + obj=self, object_type=object_type, object_id=object_id, default=default + ) # listeners/waiters diff --git a/discord/guild.py b/discord/guild.py index 6a9d54537a..fb97bc1809 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -85,6 +85,7 @@ from .sticker import GuildSticker from .threads import Thread, ThreadMember from .user import User +from .utils import _FETCHABLE from .welcome_screen import WelcomeScreen, WelcomeScreenChannel from .widget import Widget @@ -863,6 +864,45 @@ def get_member(self, user_id: int, /) -> Member | None: """ return self._members.get(user_id) + async def get_or_fetch( + self: Guild, + object_type: type[_FETCHABLE], + object_id: int | None, + default: Any = MISSING, + ) -> _FETCHABLE | None: + """ + Shortcut method to get data from an object either by returning the cached version, or if it does not exist, attempting to fetch it from the API. + + Parameters + ---------- + object_type: Union[:class:`VoiceChannel`, :class:`TextChannel`, :class:`ForumChannel`, :class:`StageChannel`, :class:`CategoryChannel`, :class:`Thread`, :class:`Role`, :class:`Member`, :class:`GuildEmoji`] + Type of object to fetch or get. + + object_id: :class:`int` + ID of object to get. + + default : Any, optional + A default to return instead of raising if fetch fails. + + Returns + ------- + + Optional[Union[:class:`VoiceChannel`, :class:`TextChannel`, :class:`ForumChannel`, :class:`StageChannel`, :class:`CategoryChannel`, :class:`Thread`, :class:`Role`, :class:`Member`, :class:`GuildEmoji`]] + The object of type that was specified or ``None`` if not found. + + Raises + ------ + :exc:`NotFound` + Invalid ID for the object + :exc:`HTTPException` + An error occurred fetching the object + :exc:`Forbidden` + You do not have permission to fetch the object + """ + return await utils.get_or_fetch( + obj=self, object_type=object_type, object_id=object_id, default=default + ) + @property def premium_subscribers(self) -> list[Member]: """A list of members who have "boosted" this guild.""" @@ -2664,6 +2704,26 @@ async def delete_sticker( """ await self._state.http.delete_guild_sticker(self.id, sticker.id, reason) + def get_emoji(self, emoji_id: int, /) -> GuildEmoji | None: + """Returns an emoji with the given ID. + + .. versionadded:: 2.7 + + Parameters + ---------- + emoji_id: int + The ID to search for. + + Returns + ------- + Optional[:class:`Emoji`] + The returned Emoji or ``None`` if not found. + """ + emoji = self._state.get_emoji(emoji_id) + if emoji and emoji.guild == self: + return emoji + return None + async def fetch_emojis(self) -> list[GuildEmoji]: r"""|coro| diff --git a/discord/utils.py b/discord/utils.py index b509162cf0..7ec7165694 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -62,6 +62,23 @@ overload, ) +if TYPE_CHECKING: + from discord import ( + Client, + VoiceChannel, + TextChannel, + ForumChannel, + StageChannel, + CategoryChannel, + Thread, + Member, + User, + Guild, + Role, + GuildEmoji, + AppEmoji, + ) + from .errors import HTTPException, InvalidArgument try: @@ -97,6 +114,7 @@ "generate_snowflake", "basic_autocomplete", "filter_params", + "MISSING", ) DISCORD_EPOCH = 1420070400000 @@ -567,64 +585,163 @@ def get(iterable: Iterable[T], **attrs: Any) -> T | None: return None -async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: - """|coro| +_FETCHABLE = TypeVar( + "_FETCHABLE", + bound="VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Member | User | Guild | Role | GuildEmoji | AppEmoji", +) - Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. - If the fetch also fails, an error will be raised. + +async def get_or_fetch( + obj: Guild | Client, + object_type: type[_FETCHABLE] = MISSING, + object_id: int | None = MISSING, + default: Any = MISSING, + attr: str = MISSING, + id: int = MISSING, +) -> _FETCHABLE | None: # TODO: Remove in 3.0 the arguments attr and id + """ + Shortcut method to get data from an object either by returning the cached version, or if it does not exist, attempting to fetch it from the API. Parameters ---------- - obj: Any - The object to use the get or fetch methods in - attr: :class:`str` - The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. - id: :class:`int` - The ID of the object - default: Any - The default value to return if the object is not found, instead of raising an error. + obj : Guild | Client + The object to operate on. + object_type: Union[:class:`VoiceChannel`, :class:`TextChannel`, :class:`ForumChannel`, :class:`StageChannel`, :class:`CategoryChannel`, :class:`Thread`, :class:`User`, :class:`Guild`, :class:`Role`, :class:`Member`, :class:`GuildEmoji`, :class:`AppEmoji`] + Type of object to fetch or get. + + object_id: :class:`int` + ID of object to get. + + default : Any, optional + A default to return instead of raising if fetch fails. Returns ------- - Any - The object found or the default value. + + Optional[Union[:class:`VoiceChannel`, :class:`TextChannel`, :class:`ForumChannel`, :class:`StageChannel`, :class:`CategoryChannel`, :class:`Thread`, :class:`User`, :class:`Guild`, :class:`Role`, :class:`Member`, :class:`GuildEmoji`, :class:`AppEmoji`]] + The object of type that was specified or ``None`` if not found. Raises ------ - :exc:`AttributeError` - The object is missing a ``get_`` or ``fetch_`` method :exc:`NotFound` Invalid ID for the object :exc:`HTTPException` An error occurred fetching the object :exc:`Forbidden` You do not have permission to fetch the object + """ + from discord import AppEmoji, Client, Guild, Member, Role, User, abc, emoji - Examples - -------- + if object_id is None: + return None + + string_to_type = { + "channel": abc.GuildChannel, + "member": Member, + "user": User, + "guild": Guild, + "emoji": emoji._EmojiTag, + "appemoji": AppEmoji, + "role": Role, + } + + if attr is not MISSING or id is not MISSING or isinstance(object_type, str): + warn_deprecated( + name="get_or_fetch(obj, attr='type', id=...)", + instead="get_or_fetch(obj, object_type=Type, object_id=...)", + since="2.7", + ) - Getting a guild from a guild ID: :: + deprecated_attr = attr if attr is not MISSING else object_type + deprecated_id = id if id is not MISSING else object_id + + if isinstance(deprecated_attr, str): + mapped_type = string_to_type.get(deprecated_attr.lower()) + if mapped_type is None: + raise InvalidArgument( + f"Unknown type string '{deprecated_attr}' used. Please use a valid object class like `discord.Member` instead." + ) + object_type = mapped_type + elif isinstance(deprecated_attr, type): + object_type = deprecated_attr + else: + raise TypeError( + f"Invalid `attr` or `object_type`: expected a string or class, got {type(deprecated_attr).__name__}." + ) - guild = await utils.get_or_fetch(client, 'guild', guild_id) + object_id = deprecated_id - Getting a channel from the guild. If the channel is not found, return None: :: + if object_type is MISSING or object_id is MISSING: + raise TypeError("required parameters: `object_type` and `object_id`.") - channel = await utils.get_or_fetch(guild, 'channel', channel_id, default=None) - """ - getter = getattr(obj, f"get_{attr}")(id) - if getter is None: - try: - getter = await getattr(obj, f"fetch_{attr}")(id) - except AttributeError: - getter = await getattr(obj, f"_fetch_{attr}")(id) - if getter is None: - raise ValueError(f"Could not find {attr} with id {id} on {obj}") - except (HTTPException, ValueError): - if default is not MISSING: - return default - else: - raise - return getter + if issubclass(object_type, (Member, User, Guild)): + attr = object_type.__name__.lower() + elif issubclass(object_type, emoji._EmojiTag): + attr = "emoji" + elif issubclass(object_type, Role): + attr = "role" + elif issubclass(object_type, abc.GuildChannel): + attr = "channel" + else: + raise InvalidArgument( + f"Class {object_type.__name__} cannot be used with discord.{type(obj).__name__}.get_or_fetch()" + ) + + if isinstance(obj, Guild) and object_type is User: + raise InvalidArgument( + "Guild cannot get_or_fetch discord.User. Use Client instead." + ) + elif isinstance(obj, Client) and object_type is Member: + raise InvalidArgument("Client cannot get_or_fetch Member. Use Guild instead.") + elif isinstance(obj, Client) and object_type is Role: + raise InvalidArgument("Client cannot get_or_fetch Role. Use Guild instead.") + elif isinstance(obj, Guild) and object_type is Guild: + raise InvalidArgument("Guild cannot get_or_fetch Guild. Use Client instead.") + + getter_fetcher_map = { + Member: ( + lambda obj, oid: obj.get_member(oid), + lambda obj, oid: obj.fetch_member(oid), + ), + Role: ( + lambda obj, oid: obj.get_role(oid), + lambda obj, oid: obj._fetch_role(oid), + ), + User: ( + lambda obj, oid: obj.get_user(oid), + lambda obj, oid: obj.fetch_user(oid), + ), + Guild: ( + lambda obj, oid: obj.get_guild(oid), + lambda obj, oid: obj.fetch_guild(oid), + ), + emoji._EmojiTag: ( + lambda obj, oid: obj.get_emoji(oid), + lambda obj, oid: obj.fetch_emoji(oid), + ), + abc.GuildChannel: ( + lambda obj, oid: obj.get_channel(oid), + lambda obj, oid: obj.fetch_channel(oid), + ), + } + try: + base_type = next( + base for base in getter_fetcher_map if issubclass(object_type, base) + ) + getter, fetcher = getter_fetcher_map[base_type] + except KeyError: + raise InvalidArgument(f"Unsupported object type: {object_type.__name__}") + + result = getter(obj, object_id) + if result is not None: + return result + + try: + return await fetcher(obj, object_id) + except (HTTPException, ValueError): + if default is not None: + return default + raise def _unique(iterable: Iterable[T]) -> list[T]: