@@ -42,14 +42,12 @@ def get_version_dist(name=__name__):
4242__all__ = ["Optional" , "Required" , "Choice" , "complete" ]
4343log = logging .getLogger (__name__ )
4444
45- CHOICE_FUNCTIONS_BASH = {
46- "file" : "_shtab_compgen_files" ,
47- "directory" : "_shtab_compgen_dirs" ,
48- }
49- CHOICE_FUNCTIONS_ZSH = {
50- "file" : "_files" ,
51- "directory" : "_files -/" ,
45+ CHOICE_FUNCTIONS = {
46+ "file" : {"bash" : "_shtab_compgen_files" , "zsh" : "_files" },
47+ "directory" : {"bash" : "_shtab_compgen_dirs" , "zsh" : "_files -/" },
5248}
49+ FILE = CHOICE_FUNCTIONS ["file" ]
50+ DIRECTORY = DIR = CHOICE_FUNCTIONS ["directory" ]
5351FLAG_OPTION = (
5452 _StoreConstAction ,
5553 _HelpAction ,
@@ -109,25 +107,21 @@ class Required(object):
109107 DIR = DIRECTORY = [Choice ("directory" , True )]
110108
111109
110+ def complete2pattern (opt_complete , shell , choice_type2fn ):
111+ return (
112+ opt_complete .get (shell , "" )
113+ if isinstance (opt_complete , dict )
114+ else choice_type2fn [opt_complete ]
115+ )
116+
117+
112118def replace_format (string , ** fmt ):
113119 """Similar to `string.format(**fmt)` but ignores unknown `{key}`s."""
114120 for k , v in fmt .items ():
115121 string = string .replace ("{" + k + "}" , v )
116122 return string
117123
118124
119- def get_optional_actions (parser ):
120- """Flattened list of all `parser`'s optional actions."""
121- return sum (
122- (
123- opt .option_strings
124- for opt in parser ._get_optional_actions ()
125- if opt .help != SUPPRESS
126- ),
127- [],
128- )
129-
130-
131125def get_bash_commands (root_parser , root_prefix , choice_functions = None ):
132126 """
133127 Recursive subcommand parser traversal, printing bash helper syntax.
@@ -145,13 +139,24 @@ def get_bash_commands(root_parser, root_prefix, choice_functions=None):
145139 # `add_argument('subcommand', choices=shtab.Required.FILE)`)
146140 _{root_parser.prog}_{subcommand}_COMPGEN=_shtab_compgen_files
147141 """
148- choice_type2fn = dict ( CHOICE_FUNCTIONS_BASH )
142+ choice_type2fn = { k : v [ "bash" ] for k , v in CHOICE_FUNCTIONS . items ()}
149143 if choice_functions :
150144 choice_type2fn .update (choice_functions )
151145
152146 fd = io .StringIO ()
153147 root_options = []
154148
149+ def get_optional_actions (parser ):
150+ """Flattened list of all `parser`'s optional actions."""
151+ return sum (
152+ (
153+ opt .option_strings
154+ for opt in parser ._get_optional_actions ()
155+ if opt .help != SUPPRESS
156+ ),
157+ [],
158+ )
159+
155160 def recurse (parser , prefix ):
156161 positionals = parser ._get_positional_actions ()
157162 commands = []
@@ -176,6 +181,16 @@ def recurse(parser, prefix):
176181 for sub in positionals :
177182 if sub .choices :
178183 log .debug ("choices:{}:{}" .format (prefix , sorted (sub .choices )))
184+ if hasattr (sub , "complete" ):
185+ print (
186+ "{}_COMPGEN={}" .format (
187+ prefix ,
188+ complete2pattern (
189+ sub .complete , "bash" , choice_type2fn
190+ ),
191+ ),
192+ file = fd ,
193+ )
179194 for cmd in sorted (sub .choices ):
180195 if isinstance (cmd , Choice ):
181196 log .debug (
@@ -342,7 +357,7 @@ def complete_zsh(parser, root_prefix=None, preamble="", choice_functions=None):
342357 root_arguments = []
343358 subcommands = {} # {cmd: {"help": help, "arguments": [arguments]}}
344359
345- choice_type2fn = dict ( CHOICE_FUNCTIONS_ZSH )
360+ choice_type2fn = { k : v [ "zsh" ] for k , v in CHOICE_FUNCTIONS . items ()}
346361 if choice_functions :
347362 choice_type2fn .update (choice_functions )
348363
@@ -368,7 +383,9 @@ def format_optional(opt):
368383 ),
369384 help = escape_zsh (opt .help or "" ),
370385 dest = opt .dest ,
371- pattern = (
386+ pattern = complete2pattern (opt .complete , "zsh" , choice_type2fn )
387+ if hasattr (opt , "complete" )
388+ else (
372389 choice_type2fn [opt .choices [0 ].type ]
373390 if isinstance (opt .choices [0 ], Choice )
374391 else "({})" .format (" " .join (opt .choices ))
@@ -380,10 +397,12 @@ def format_optional(opt):
380397 )
381398
382399 def format_positional (opt ):
383- return '"{nargs}:{help}:{choices }"' .format (
400+ return '"{nargs}:{help}:{pattern }"' .format (
384401 nargs = {"+" : "*" , "*" : "*" }.get (opt .nargs , "" ),
385402 help = escape_zsh ((opt .help or opt .dest ).strip ().split ("\n " )[0 ]),
386- choices = (
403+ pattern = complete2pattern (opt .complete , "zsh" , choice_type2fn )
404+ if hasattr (opt , "complete" )
405+ else (
387406 choice_type2fn [opt .choices [0 ].type ]
388407 if isinstance (opt .choices [0 ], Choice )
389408 else "({})" .format (" " .join (opt .choices ))
@@ -529,10 +548,15 @@ def complete(
529548 shell : str (bash/zsh)
530549 root_prefix : str, prefix for shell functions to avoid clashes
531550 (default: "_{parser.prog}")
532- preamble : str, prepended to generated script
533- choice_functions : dict, maps custom `shtab.Choice.type`s to
534- completion functions (possibly defined in `preamble`)
551+ preamble : dict, mapping shell to text to prepend to generated script
552+ (e.g. `{"bash": "_myprog_custom_function(){ echo hello }"}`)
553+ choice_functions : deprecated.
554+
555+ N.B. `parser.add_argument().complete = ...` can be used to define custom
556+ completions (e.g. filenames). See <../examples/pathcomplete.py>.
535557 """
558+ if isinstance (preamble , dict ):
559+ preamble = preamble .get (shell , "" )
536560 if shell == "bash" :
537561 return complete_bash (
538562 parser ,
0 commit comments