diff --git a/argparse.py b/argparse.py index 3bdcd33..c271811 100644 --- a/argparse.py +++ b/argparse.py @@ -90,6 +90,9 @@ from gettext import gettext as _ +# XXX.bewest: remove +from pprint import pprint + try: set except NameError: @@ -123,6 +126,7 @@ def _callable(obj): ZERO_OR_MORE = '*' ONE_OR_MORE = '+' PARSER = 'A...' +_OPTIONAL_PARSER = 'A...?' REMAINDER = '...' _UNRECOGNIZED_ARGS_ATTR = '_unrecognized_args' @@ -601,7 +605,7 @@ def _format_args(self, action, default_metavar): result = '%s [%s ...]' % get_metavar(2) elif action.nargs == REMAINDER: result = '...' - elif action.nargs == PARSER: + elif action.nargs in [ PARSER, _OPTIONAL_PARSER ]: result = '%s ...' % get_metavar(1) else: formats = ['%s' for _ in range(action.nargs)] @@ -1055,6 +1059,7 @@ def __init__(self, parser_class, dest=SUPPRESS, help=None, + default=None, metavar=None): self._prog_prefix = prog @@ -1062,10 +1067,15 @@ def __init__(self, self._name_parser_map = {} self._choices_actions = [] + nargs = PARSER + if default is not None: + nargs = _OPTIONAL_PARSER + super(_SubParsersAction, self).__init__( option_strings=option_strings, dest=dest, - nargs=PARSER, + nargs=nargs, + default=default, choices=self._name_parser_map, help=help, metavar=metavar) @@ -2173,6 +2183,11 @@ def _get_nargs_pattern(self, action): elif nargs == PARSER: nargs_pattern = '(-*A[-AO]*)' + # allow one optional argument followed by any number of options or + # arguments + elif nargs == _OPTIONAL_PARSER: + nargs_pattern = '(-*A?[-AO]*)?' + # all others should be integers else: nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs) @@ -2190,7 +2205,8 @@ def _get_nargs_pattern(self, action): # ======================== def _get_values(self, action, arg_strings): # for everything but PARSER args, strip out '--' - if action.nargs not in [PARSER, REMAINDER]: + #if action.nargs not in [PARSER, REMAINDER]: + if action.nargs not in [_OPTIONAL_PARSER, PARSER, REMAINDER]: arg_strings = [s for s in arg_strings if s != '--'] # optional argument produces a default when not present @@ -2202,7 +2218,6 @@ def _get_values(self, action, arg_strings): if isinstance(value, basestring): value = self._get_value(action, value) self._check_value(action, value) - # when nargs='*' on a positional, if there were no command-line # args, use the default if it is anything other than None elif (not arg_strings and action.nargs == ZERO_OR_MORE and @@ -2224,8 +2239,11 @@ def _get_values(self, action, arg_strings): value = [self._get_value(action, v) for v in arg_strings] # PARSER arguments convert all values, but check only the first - elif action.nargs == PARSER: + elif action.nargs in [ PARSER, _OPTIONAL_PARSER ]: value = [self._get_value(action, v) for v in arg_strings] + if (not arg_strings and action.default is not None + and action.nargs == _OPTIONAL_PARSER): + value = [action.default] self._check_value(action, value[0]) # all other types of nargs produce a list