Skip to content

Commit f57b7fa

Browse files
authored
Merge pull request #34 from eunwoo1104/command-as-model
Changed slash command to the model, added proper Cog support
2 parents de594f6 + 36c3c53 commit f57b7fa

File tree

7 files changed

+366
-71
lines changed

7 files changed

+366
-71
lines changed

README.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,30 @@ async def _test(ctx: SlashContext):
2222
bot.run("discord_token")
2323
```
2424

25-
Cog (__Not Recommended__):
25+
Cog:
2626
```py
2727
import discord
2828
from discord.ext import commands
29+
from discord_slash import cog_ext
2930
from discord_slash import SlashCommand
3031
from discord_slash import SlashContext
3132

3233

3334
class Slash(commands.Cog):
3435
def __init__(self, bot):
36+
if not hasattr(bot, "slash"):
37+
# Creates new SlashCommand instance to bot if bot doesn't have.
38+
bot.slash = SlashCommand(bot, override_type=True)
3539
self.bot = bot
36-
self.slash = SlashCommand(bot, override_type=True)
37-
# Cog is only supported by commands ext, so just skip checking type.
38-
39-
# Make sure all commands should be inside `__init__`
40-
# or some other functions that can put commands.
41-
@self.slash.slash(name="test")
42-
async def _test(ctx: SlashContext):
43-
await ctx.send(content="Hello, World!")
40+
self.bot.slash.get_cog_commands(self)
4441

4542
def cog_unload(self):
46-
self.slash.remove()
43+
self.bot.slash.remove_cog_commands(self)
44+
45+
@cog_ext.cog_slash(name="test")
46+
async def _test(self, ctx: SlashContext):
47+
embed = discord.Embed(title="embed test")
48+
await ctx.send(content="test", embeds=[embed])
4749

4850

4951
def setup(bot):
@@ -72,5 +74,4 @@ Or you can ask at [Discussions](https://github.com/eunwoo1104/discord-py-slash-c
7274

7375
## TODO
7476
- Rewrite `http.py` and webhook part (Maybe use discord.py's webhook support?)
75-
- Properly support Cog
7677
- Try supporting most of the features supported by discord.py commands extension

discord_slash/client.py

Lines changed: 113 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class SlashCommand:
2626
:ivar auto_register: Whether to register commands automatically.
2727
:ivar has_listener: Whether discord client has listener add function.
2828
"""
29+
2930
def __init__(self,
3031
client: typing.Union[discord.Client, commands.Bot],
3132
auto_register: bool = False,
@@ -38,7 +39,8 @@ def __init__(self,
3839
self.auto_register = auto_register
3940
if self.auto_register:
4041
self._discord.loop.create_task(self.register_all_commands())
41-
if not isinstance(client, commands.Bot) and not isinstance(client, commands.AutoShardedBot) and not override_type:
42+
if not isinstance(client, commands.Bot) and not isinstance(client,
43+
commands.AutoShardedBot) and not override_type:
4244
self.logger.info("Detected discord.Client! Overriding on_socket_response.")
4345
self._discord.on_socket_response = self.on_socket_response
4446
self.has_listener = False
@@ -50,6 +52,9 @@ def remove(self):
5052
"""
5153
Removes :func:`on_socket_response` event listener from discord.py Client.
5254
55+
.. warning::
56+
This is deprecated and will be removed soon.
57+
5358
.. note::
5459
This only works if it is :class:`discord.ext.commands.Bot` or
5560
:class:`discord.ext.commands.AutoShardedBot`.
@@ -58,35 +63,107 @@ def remove(self):
5863
return
5964
self._discord.remove_listener(self.on_socket_response)
6065

66+
def get_cog_commands(self, cog: commands.Cog):
67+
"""
68+
Gets slash command from :class:`discord.ext.commands.Cog`.
69+
70+
:param cog: Cog that has slash commands.
71+
:type cog: discord.ext.commands.Cog
72+
"""
73+
func_list = [getattr(cog, x) for x in dir(cog)]
74+
res = [x for x in func_list if
75+
isinstance(x, model.CogCommandObject) or isinstance(x, model.CogSubcommandObject)]
76+
for x in res:
77+
x.cog = cog
78+
if isinstance(x, model.CogCommandObject):
79+
self.commands[x.name] = x
80+
else:
81+
if x.base in self.commands.keys():
82+
self.commands[x.base].allowed_guild_ids += x.allowed_guild_ids
83+
self.commands[x.base].has_subcommands = True
84+
else:
85+
_cmd = {
86+
"func": None,
87+
"description": "No description.",
88+
"auto_convert": {},
89+
"guild_ids": x.allowed_guild_ids,
90+
"api_options": [],
91+
"has_subcommands": True
92+
}
93+
self.commands[x.base] = model.CommandObject(x.base, _cmd)
94+
if x.base not in self.subcommands.keys():
95+
self.subcommands[x.base] = {}
96+
if x.subcommand_group:
97+
if x.subcommand_group not in self.subcommands:
98+
self.subcommands[x.base][x.subcommand_group] = {}
99+
self.subcommands[x.base][x.subcommand_group][x.name] = x
100+
else:
101+
self.subcommands[x.base][x.name] = x
102+
103+
def remove_cog_commands(self, cog):
104+
"""
105+
Removes slash command from :class:`discord.ext.commands.Cog`.
106+
107+
:param cog: Cog that has slash commands.
108+
:type cog: discord.ext.commands.Cog
109+
"""
110+
func_list = [getattr(cog, x) for x in dir(cog)]
111+
res = [x for x in func_list if
112+
isinstance(x, model.CogCommandObject) or isinstance(x, model.CogSubcommandObject)]
113+
for x in res:
114+
if isinstance(x, model.CogCommandObject):
115+
if x.name not in self.commands.keys():
116+
continue # Just in case it is removed due to subcommand.
117+
if x.name in self.subcommands.keys():
118+
self.commands[x.name].func = None
119+
continue # Let's remove completely when every subcommand is removed.
120+
del self.commands[x.name]
121+
else:
122+
if x.base not in self.subcommands.keys():
123+
continue # Just in case...
124+
if x.subcommand_group:
125+
del self.subcommands[x.base][x.subcommand_group][x.name]
126+
if not self.subcommands[x.base][x.subcommand_group]:
127+
del self.subcommands[x.base][x.subcommand_group]
128+
else:
129+
del self.subcommands[x.base][x.name]
130+
if not self.subcommands[x.base]:
131+
del self.subcommands[x.base]
132+
if x.base in self.commands.keys():
133+
if self.commands[x.base].func:
134+
self.commands[x.base].has_subcommands = False
135+
else:
136+
del self.commands[x.base]
137+
61138
async def register_all_commands(self):
62139
"""
63140
Registers all slash commands except subcommands to Discord API.\n
64141
If ``auto_register`` is ``True``, then this will be automatically called.
65142
"""
66-
await self._discord.wait_until_ready() # In case commands are still not registered to SlashCommand.
143+
await self._discord.wait_until_ready() # In case commands are still not registered to SlashCommand.
67144
self.logger.info("Registering commands...")
68145
for x in self.commands.keys():
69146
selected = self.commands[x]
70-
if selected["has_subcommands"] and "func" not in selected.keys():
147+
if selected.has_subcommands and not hasattr(selected, "invoke"):
71148
# Just in case it has subcommands but also has base command.
72149
# More specific, it will skip if it has subcommands and doesn't have base command coroutine.
73150
self.logger.debug("Skipping registering subcommands.")
74151
continue
75-
if selected["guild_ids"]:
76-
for y in selected["guild_ids"]:
152+
if selected.allowed_guild_ids:
153+
for y in selected.allowed_guild_ids:
77154
await manage_commands.add_slash_command(self._discord.user.id,
78155
self._discord.http.token,
79156
y,
80157
x,
81-
selected["description"],
82-
selected["api_options"])
158+
selected.description,
159+
selected.options)
83160
else:
84161
await manage_commands.add_slash_command(self._discord.user.id,
85162
self._discord.http.token,
86163
None,
87164
x,
88-
selected["description"],
89-
selected["api_options"])
165+
selected.description,
166+
selected.options)
90167
self.logger.info("Completed registering all commands!")
91168

92169
def add_slash_command(self,
@@ -129,7 +206,7 @@ def add_slash_command(self,
129206
"api_options": options if options else [],
130207
"has_subcommands": has_subcommands
131208
}
132-
self.commands[name] = _cmd
209+
self.commands[name] = model.CommandObject(name, _cmd)
133210
self.logger.debug(f"Added command `{name}`")
134211

135212
def add_subcommand(self,
@@ -163,7 +240,11 @@ def add_subcommand(self,
163240
name = cmd.__name__ if not name else name
164241
name = name.lower()
165242
_cmd = {
243+
"func": None,
244+
"description": "No description.",
245+
"auto_convert": {},
166246
"guild_ids": guild_ids,
247+
"api_options": [],
167248
"has_subcommands": True
168249
}
169250
_sub = {
@@ -174,18 +255,20 @@ def add_subcommand(self,
174255
"guild_ids": guild_ids,
175256
}
176257
if base not in self.commands.keys():
177-
self.commands[base] = _cmd
258+
self.commands[base] = model.CommandObject(base, _cmd)
178259
else:
179-
self.subcommands[base]["has_subcommands"] = True
260+
self.commands[base].has_subcommands = True
261+
self.commands[base].allowed_guild_ids += guild_ids
180262
if base not in self.subcommands.keys():
181263
self.subcommands[base] = {}
182264
if subcommand_group:
183265
if subcommand_group not in self.subcommands[base].keys():
184266
self.subcommands[base][subcommand_group] = {}
185-
self.subcommands[base][subcommand_group][name] = _sub
267+
self.subcommands[base][subcommand_group][name] = model.SubcommandObject(_sub, base, name, subcommand_group)
186268
else:
187-
self.subcommands[base][name] = _sub
188-
self.logger.debug(f"Added subcommand `{base} {subcommand_group if subcommand_group else ''} {cmd.__name__ if not name else name}`")
269+
self.subcommands[base][name] = model.SubcommandObject(_sub, base, name)
270+
self.logger.debug(
271+
f"Added subcommand `{base} {subcommand_group if subcommand_group else ''} {cmd.__name__ if not name else name}`")
189272

190273
def slash(self,
191274
*,
@@ -259,6 +342,7 @@ async def _pick(ctx, choice1, choice2): # Command with 1 or more args.
259342
def wrapper(cmd):
260343
self.add_slash_command(cmd, name, description, auto_convert, guild_ids, options)
261344
return cmd
345+
262346
return wrapper
263347

264348
def subcommand(self,
@@ -311,6 +395,7 @@ async def _group_kick_user(ctx, user):
311395
def wrapper(cmd):
312396
self.add_subcommand(cmd, base, subcommand_group, name, description, auto_convert, guild_ids)
313397
return cmd
398+
314399
return wrapper
315400

316401
async def process_options(self, guild: discord.Guild, options: list, auto_convert: dict) -> list:
@@ -388,18 +473,21 @@ async def on_socket_response(self, msg):
388473
return
389474
to_use = msg["d"]
390475
if to_use["data"]["name"] in self.commands.keys():
391-
selected_cmd = self.commands[to_use["data"]["name"]]
392476
ctx = model.SlashContext(self.req, to_use, self._discord, self.logger)
393-
if selected_cmd["guild_ids"]:
394-
if ctx.guild.id not in selected_cmd["guild_ids"]:
477+
cmd_name = to_use["data"]["name"]
478+
if cmd_name not in self.commands.keys() and cmd_name in self.subcommands.keys():
479+
return await self.handle_subcommand(ctx, to_use)
480+
selected_cmd = self.commands[to_use["data"]["name"]]
481+
if selected_cmd.allowed_guild_ids:
482+
if ctx.guild.id not in selected_cmd.allowed_guild_ids:
395483
return
396-
if selected_cmd["has_subcommands"]:
484+
if selected_cmd.has_subcommands:
397485
return await self.handle_subcommand(ctx, to_use)
398-
args = await self.process_options(ctx.guild, to_use["data"]["options"], selected_cmd["auto_convert"]) \
486+
args = await self.process_options(ctx.guild, to_use["data"]["options"], selected_cmd.auto_convert) \
399487
if "options" in to_use["data"] else []
400488
self._discord.dispatch("slash_command", ctx)
401489
try:
402-
await selected_cmd["func"](ctx, *args)
490+
await selected_cmd.invoke(ctx, *args)
403491
except Exception as ex:
404492
await self.on_slash_command_error(ctx, ex)
405493

@@ -429,20 +517,20 @@ async def handle_subcommand(self, ctx: model.SlashContext, data: dict):
429517
return
430518
ctx.subcommand_group = sub_group
431519
selected = base[sub_name][sub_group]
432-
args = await self.process_options(ctx.guild, x["options"], selected["auto_convert"]) \
520+
args = await self.process_options(ctx.guild, x["options"], selected.auto_convert) \
433521
if "options" in x.keys() else []
434522
self._discord.dispatch("slash_command", ctx)
435523
try:
436-
await selected["func"](ctx, *args)
524+
await selected.invoke(ctx, *args)
437525
except Exception as ex:
438526
await self.on_slash_command_error(ctx, ex)
439527
return
440528
selected = base[sub_name]
441-
args = await self.process_options(ctx.guild, sub_opts, selected["auto_convert"]) \
529+
args = await self.process_options(ctx.guild, sub_opts, selected.auto_convert) \
442530
if "options" in sub.keys() else []
443531
self._discord.dispatch("slash_command", ctx)
444532
try:
445-
await selected["func"](ctx, *args)
533+
await selected.invoke(ctx, *args)
446534
except Exception as ex:
447535
await self.on_slash_command_error(ctx, ex)
448536

0 commit comments

Comments
 (0)