Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 68 additions & 35 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ type FlagSet struct {
normalizeNameFunc func(f *FlagSet, name string) NormalizedName

addedGoFlagSets []*goflag.FlagSet
unknownFlags []*Flag
}

// A Flag represents the state of a flag.
Expand All @@ -182,6 +183,12 @@ type Flag struct {
Annotations map[string][]string // used by cobra.Command bash autocomple code
}

// A UnknownFlag represents the state of a flag that is not expected.
type UnknownFlag struct {
Name string // name as it appears on command line
Value Value // value as set
}

// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
type Value interface {
Expand Down Expand Up @@ -275,6 +282,17 @@ func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output
}

// VisitUnknowns visits all the flags that have not been registered.
func (f *FlagSet) VisitUnknowns(fn func(*Flag)) {
if len(f.unknownFlags) == 0 {
return
}

for _, flag := range f.unknownFlags {
fn(flag)
}
}

// VisitAll visits the flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
Expand Down Expand Up @@ -956,6 +974,18 @@ func stripUnknownFlagValue(args []string) []string {
return nil
}

func createUnknownFlag(name string, value string) *Flag {
flag := new(Flag)
flag.Name = name
flag.Value = newStringValue(value, &value)
return flag
}

func (f *FlagSet) addUnknownFlag(name string, value string) {
flag := createUnknownFlag(name, value)
f.unknownFlags = append(f.unknownFlags, flag)
}

func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
name := s[2:]
Expand All @@ -969,19 +999,11 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
flag, exists := f.formal[f.normalizeFlagName(name)]

if !exists {
switch {
case name == "help":
if name == "help" {
f.usage()
return a, ErrHelp
case f.ParseErrorsWhitelist.UnknownFlags:
// --unknown=unknownval arg ...
// we do not want to lose arg in this case
if len(split) >= 2 {
return a, nil
}

return stripUnknownFlagValue(a), nil
default:
}
if !f.ParseErrorsWhitelist.UnknownFlags {
err = f.failf("unknown flag: --%s", name)
return
}
Expand All @@ -991,16 +1013,23 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
if len(split) == 2 {
// '--flag=arg'
value = split[1]
} else if flag.NoOptDefVal != "" {
} else if exists && flag.NoOptDefVal != "" {
// '--flag' (arg was optional)
value = flag.NoOptDefVal
} else if len(a) > 0 {
// '--flag arg'
value = a[0]
a = a[1:]
} else {
// '--flag' (arg was required)
err = f.failf("flag needs an argument: %s", s)
if !exists && strings.HasPrefix(a[0], "-") {
value = ""
} else {
value = a[0]
a = a[1:]
}
} else if f.ParseErrorsWhitelist.UnknownFlags {
value = ""
}

if !exists && f.ParseErrorsWhitelist.UnknownFlags {
f.addUnknownFlag(name, value)
return
}

Expand All @@ -1023,22 +1052,12 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

flag, exists := f.shorthands[c]
if !exists {
switch {
case c == 'h':
if c == 'h' {
f.usage()
err = ErrHelp
return
case f.ParseErrorsWhitelist.UnknownFlags:
// '-f=arg arg ...'
// we do not want to lose arg in this case
if len(shorthands) > 2 && shorthands[1] == '=' {
outShorts = ""
return
}

outArgs = stripUnknownFlagValue(outArgs)
return
default:
}
if !f.ParseErrorsWhitelist.UnknownFlags {
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
return
}
Expand All @@ -1049,18 +1068,32 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
// '-f=arg'
value = shorthands[2:]
outShorts = ""
} else if flag.NoOptDefVal != "" {
// '-f' (arg was optional)
} else if exists && flag.NoOptDefVal != "" {
// '--flag' (arg was optional)
value = flag.NoOptDefVal
} else if len(shorthands) > 1 {
// '-farg'
value = shorthands[1:]
outShorts = ""
} else if len(args) > 0 {
// '-f arg'
value = args[0]
outArgs = args[1:]
} else {
if !exists && strings.HasPrefix(args[0], "-") {
value = ""
} else {
value = args[0]
outArgs = args[1:]
}

} else if f.ParseErrorsWhitelist.UnknownFlags {
value = ""
}

if !exists && f.ParseErrorsWhitelist.UnknownFlags {
f.addUnknownFlag(string(c), value)
return
}

if flag.NoOptDefVal == "" && value == "" {
// '-f' (arg was required)
err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
return
Expand Down
63 changes: 63 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,68 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
}
}

func testRetrieveUknowsWhenUnknownFlagsParsed(t *testing.T) {
f := NewFlagSet("unknwonFlags", ContinueOnError)
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
boolaFlag := f.BoolP("boola", "a", false, "bool value")
stringaFlag := f.StringP("stringa", "s", "0", "string value")

args := []string{
"-a",
"--stringa",
"hello",
"--unknownFlag1",
"unknownValue1",
"--unknownFlag2",
"--unknownFlag3=unknownValue3",
"-e",
"unknownValue4",
"-f=unknownValue5",
"-g",
}

f.ParseErrorsWhitelist.UnknownFlags = true

want := map[string]string{
"unknownFlag1": "unknownValue1",
"unknownFlag2": "",
"unknownFlag3": "unknownValue3",
"e": "unknownValue4",
"f": "unknownValue5",
"g": "",
}

f.SetOutput(ioutil.Discard)
if err := f.Parse(args); err != nil {
t.Error("expected no error, got ", err)
}
if !f.Parsed() {
t.Error("f.Parse() = false after Parse")
}
if *boolaFlag != true {
t.Error("boola flag should be true, is ", *boolaFlag)
}
if *stringaFlag != "hello" {
t.Error("stringa flag should be `hello`, is ", *stringaFlag)
}
if len(f.unknownFlags) != len(want) {
t.Errorf("f.ParseAll() failed to parse unknown flags")
}
for _, flag := range f.unknownFlags {
wantedValue, ok := want[flag.Name]
if !ok {
t.Errorf("f.unknownFlags contains a flag \"%s\" and shouldn't", flag.Name)
break
}
if wantedValue != flag.Value.String() {
t.Errorf("value for the unknown flag \"%s\" should be \"%s\", got \"%s\"", flag.Name, wantedValue, flag.Value.String())
}

}
}

func TestShorthand(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
if f.Parsed() {
Expand Down Expand Up @@ -588,6 +650,7 @@ func TestParseAll(t *testing.T) {

func TestIgnoreUnknownFlags(t *testing.T) {
ResetForTesting(func() { t.Error("bad parse") })
testRetrieveUknowsWhenUnknownFlagsParsed(t)
testParseWithUnknownFlags(GetCommandLine(), t)
}

Expand Down