diff --git a/flag.go b/flag.go index e9ca46e2..afe8d119 100644 --- a/flag.go +++ b/flag.go @@ -515,6 +515,25 @@ func (f *FlagSet) Set(name, value string) error { return nil } +// MustSetAnnotation sets arbitrary annotations on a flag in the FlagSet. It +// is similar to [FlagSet.SetAnnotation], but panics if the given flag does +// not exist or if the given annotation is already set. +func (f *FlagSet) MustSetAnnotation(name, key string, values []string) { + normalName := f.normalizeFlagName(name) + flag, ok := f.formal[normalName] + if !ok { + panic(&NotExistError{name: name, messageType: flagNoSuchFlagMessage}) + } + if flag.Annotations == nil { + flag.Annotations = map[string][]string{} + } + if _, ok := flag.Annotations[key]; ok { + panic(fmt.Errorf("annotation %q is already set for flag %q", key, name)) + } + + flag.Annotations[key] = values +} + // SetAnnotation allows one to set arbitrary annotations on a flag in the FlagSet. // This is sometimes used by spf13/cobra programs which want to generate additional // bash completion information. diff --git a/flag_test.go b/flag_test.go index c60e344b..dc3f8fd3 100644 --- a/flag_test.go +++ b/flag_test.go @@ -175,6 +175,86 @@ func TestAnnotation(t *testing.T) { } } +func TestMustSetAnnotation(t *testing.T) { + tests := []struct { + doc string + run func(f *FlagSet) + expErr string + }{ + { + doc: "missing flag", + run: func(f *FlagSet) { + f.MustSetAnnotation("missing-flag", "key", nil) + }, + expErr: "no such flag -missing-flag", + }, + { + doc: "set nil annotation", + run: func(f *FlagSet) { + f.MustSetAnnotation("stringa", "key", nil) + if got := f.Lookup("stringa").Annotations["key"]; got != nil { + t.Fatalf("unexpected annotation: %v", got) + } + }, + }, + { + doc: "set non-nil annotation", + run: func(f *FlagSet) { + f.MustSetAnnotation("stringb", "key", []string{"value1"}) + if got := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(got, []string{"value1"}) { + t.Fatalf("unexpected annotation: %v", got) + } + }, + }, + { + doc: "panic when annotation already set", + run: func(f *FlagSet) { + f.MustSetAnnotation("stringc", "key", []string{"value2"}) + }, + expErr: `annotation "key" is already set for flag "stringc"`, + }, + } + + f := NewFlagSet("shorthand", ContinueOnError) + f.StringP("stringa", "a", "", "string value") + f.StringP("stringb", "b", "", "string2 value") + f.StringP("stringc", "c", "", "string3 value") + + if err := f.SetAnnotation("stringc", "key", []string{"value1"}); err != nil { + t.Fatal(err) + } + + for _, tc := range tests { + t.Run(tc.doc, func(t *testing.T) { + defer func() { + r := recover() + + if tc.expErr == "" { + if r != nil { + t.Fatalf("unexpected panic: %v", r) + } + return + } + + if r == nil { + t.Fatalf("expected panic %q, got none", tc.expErr) + } + + err, ok := r.(error) + if !ok { + t.Fatalf("panic value is not error: %T", r) + } + + if err.Error() != tc.expErr { + t.Fatalf("expected panic error %q, got %q", tc.expErr, err.Error()) + } + }() + + tc.run(f) + }) + } +} + func TestName(t *testing.T) { flagSetName := "bob" f := NewFlagSet(flagSetName, ContinueOnError)