summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md15
-rw-r--r--example_test.go13
-rw-r--r--parse.go74
-rw-r--r--parse_test.go67
-rw-r--r--reflect.go78
-rw-r--r--reflect_test.go79
-rw-r--r--sequence.go123
-rw-r--r--sequence_test.go152
-rw-r--r--usage.go18
9 files changed, 505 insertions, 114 deletions
diff --git a/README.md b/README.md
index da69469..48fa2f0 100644
--- a/README.md
+++ b/README.md
@@ -191,6 +191,7 @@ var args struct {
Files []string `arg:"-f,separate"`
Databases []string `arg:"positional"`
}
+arg.MustParse(&args)
```
```shell
@@ -200,6 +201,20 @@ Files [file1 file2 file3]
Databases [db1 db2 db3]
```
+### Arguments with keys and values
+```go
+var args struct {
+ UserIDs map[string]int
+}
+arg.MustParse(&args)
+fmt.Println(args.UserIDs)
+```
+
+```shell
+./example --userids john=123 mary=456
+map[john:123 mary:456]
+```
+
### Custom validation
```go
var args struct {
diff --git a/example_test.go b/example_test.go
index 9091151..5645156 100644
--- a/example_test.go
+++ b/example_test.go
@@ -82,6 +82,19 @@ func Example_multipleValues() {
// output: Fetching the following IDs from localhost: [1 2 3]
}
+// This example demonstrates arguments with keys and values
+func Example_mappings() {
+ // The args you would pass in on the command line
+ os.Args = split("./example --userids john=123 mary=456")
+
+ var args struct {
+ UserIDs map[string]int
+ }
+ MustParse(&args)
+ fmt.Println(args.UserIDs)
+ // output: map[john:123 mary:456]
+}
+
// This eample demonstrates multiple value arguments that can be mixed with
// other arguments.
func Example_multipleMixed() {
diff --git a/parse.go b/parse.go
index b7d159d..d357d5c 100644
--- a/parse.go
+++ b/parse.go
@@ -50,15 +50,14 @@ type spec struct {
field reflect.StructField // the struct field from which this option was created
long string // the --long form for this option, or empty if none
short string // the -s short form for this option, or empty if none
- multiple bool
- required bool
- positional bool
- separate bool
- help string
- env string
- boolean bool
- defaultVal string // default value for this option
- placeholder string // name of the data in help
+ cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
+ required bool // if true, this option must be present on the command line
+ positional bool // if true, this option will be looked for in the positional flags
+ separate bool // if true, each slice and map entry will have its own --flag
+ help string // the help text for this option
+ env string // the name of the environment variable for this option, or empty for none
+ defaultVal string // default value for this option
+ placeholder string // name of the data in help
}
// command represents a named subcommand, or the top-level command
@@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
if !isSubcommand {
cmd.specs = append(cmd.specs, &spec)
- var parseable bool
- parseable, spec.boolean, spec.multiple = canParse(field.Type)
- if !parseable {
+ var err error
+ spec.cardinality, err = cardinalityOf(field.Type)
+ if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
}
- if spec.multiple && hasDefault {
- errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
+ if spec.cardinality == multiple && hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
t.Name(), field.Name))
return false
}
@@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
continue
}
- if spec.multiple {
+ if spec.cardinality == multiple {
// expect a CSV string in an environment
// variable in the case of multiple values
values, err := csv.NewReader(strings.NewReader(value)).Read()
@@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
err,
)
}
- if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil {
+ if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil {
return fmt.Errorf(
"error processing environment variable %s with multiple values: %v",
spec.env,
@@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error {
wasPresent[spec] = true
// deal with the case of multiple values
- if spec.multiple {
+ if spec.cardinality == multiple {
var values []string
if value == "" {
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
@@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error {
} else {
values = append(values, value)
}
- err := setSlice(p.val(spec.dest), values, !spec.separate)
+ err := setSliceOrMap(p.val(spec.dest), values, !spec.separate)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error {
// if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler
- if spec.boolean && value == "" {
+ if spec.cardinality == zero && value == "" {
value = "true"
}
@@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error {
break
}
wasPresent[spec] = true
- if spec.multiple {
- err := setSlice(p.val(spec.dest), positionals, true)
+ if spec.cardinality == multiple {
+ err := setSliceOrMap(p.val(spec.dest), positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
}
@@ -702,37 +701,6 @@ func (p *Parser) val(dest path) reflect.Value {
return v
}
-// parse a value as the appropriate type and store it in the struct
-func setSlice(dest reflect.Value, values []string, trunc bool) error {
- if !dest.CanSet() {
- return fmt.Errorf("field is not writable")
- }
-
- var ptr bool
- elem := dest.Type().Elem()
- if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
- ptr = true
- elem = elem.Elem()
- }
-
- // Truncate the dest slice in case default values exist
- if trunc && !dest.IsNil() {
- dest.SetLen(0)
- }
-
- for _, s := range values {
- v := reflect.New(elem)
- if err := scalar.ParseValue(v.Elem(), s); err != nil {
- return err
- }
- if !ptr {
- v = v.Elem()
- }
- dest.Set(reflect.Append(dest, v))
- }
- return nil
-}
-
// findOption finds an option from its name, or returns null if no spec is found
func findOption(specs []*spec, name string) *spec {
for _, spec := range specs {
@@ -759,7 +727,7 @@ func findSubcommand(cmds []*command, name string) *command {
// isZero returns true if v contains the zero value for its type
func isZero(v reflect.Value) bool {
t := v.Type()
- if t.Kind() == reflect.Slice {
+ if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
return v.IsNil()
}
if !t.Comparable() {
diff --git a/parse_test.go b/parse_test.go
index ce3068e..d03cbfd 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -220,6 +220,60 @@ func TestLongFlag(t *testing.T) {
assert.Equal(t, "xyz", args.Foo)
}
+func TestSlice(t *testing.T) {
+ var args struct {
+ Strings []string
+ }
+ err := parse("--strings a b c", &args)
+ require.NoError(t, err)
+ assert.Equal(t, []string{"a", "b", "c"}, args.Strings)
+}
+func TestSliceOfBools(t *testing.T) {
+ var args struct {
+ B []bool
+ }
+
+ err := parse("--b true false true", &args)
+ require.NoError(t, err)
+ assert.Equal(t, []bool{true, false, true}, args.B)
+}
+
+func TestMap(t *testing.T) {
+ var args struct {
+ Values map[string]int
+ }
+ err := parse("--values a=1 b=2 c=3", &args)
+ require.NoError(t, err)
+ assert.Len(t, args.Values, 3)
+ assert.Equal(t, 1, args.Values["a"])
+ assert.Equal(t, 2, args.Values["b"])
+ assert.Equal(t, 3, args.Values["c"])
+}
+
+func TestMapPositional(t *testing.T) {
+ var args struct {
+ Values map[string]int `arg:"positional"`
+ }
+ err := parse("a=1 b=2 c=3", &args)
+ require.NoError(t, err)
+ assert.Len(t, args.Values, 3)
+ assert.Equal(t, 1, args.Values["a"])
+ assert.Equal(t, 2, args.Values["b"])
+ assert.Equal(t, 3, args.Values["c"])
+}
+
+func TestMapWithSeparate(t *testing.T) {
+ var args struct {
+ Values map[string]int `arg:"separate"`
+ }
+ err := parse("--values a=1 --values b=2 --values c=3", &args)
+ require.NoError(t, err)
+ assert.Len(t, args.Values, 3)
+ assert.Equal(t, 1, args.Values["a"])
+ assert.Equal(t, 2, args.Values["b"])
+ assert.Equal(t, 3, args.Values["c"])
+}
+
func TestPlaceholder(t *testing.T) {
var args struct {
Input string `arg:"positional" placeholder:"SRC"`
@@ -688,6 +742,17 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) {
assert.Error(t, err)
}
+func TestEnvironmentVariableMap(t *testing.T) {
+ var args struct {
+ Foo map[int]string `arg:"env"`
+ }
+ setenv(t, "FOO", "1=one,99=ninetynine")
+ MustParse(&args)
+ assert.Len(t, args.Foo, 2)
+ assert.Equal(t, "one", args.Foo[1])
+ assert.Equal(t, "ninetynine", args.Foo[99])
+}
+
func TestEnvironmentVariableIgnored(t *testing.T) {
var args struct {
Foo string `arg:"env"`
@@ -1223,7 +1288,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
}
err := parse("", &args)
- assert.EqualError(t, err, ".A: default values are not supported for slice fields")
+ assert.EqualError(t, err, ".A: default values are not supported for slice or map fields")
}
func TestUnexportedFieldsSkipped(t *testing.T) {
diff --git a/reflect.go b/reflect.go
index f1e8e8d..1806973 100644
--- a/reflect.go
+++ b/reflect.go
@@ -2,6 +2,7 @@ package arg
import (
"encoding"
+ "fmt"
"reflect"
"unicode"
"unicode/utf8"
@@ -11,42 +12,67 @@ import (
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
-// canParse returns true if the type can be parsed from a string
-func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
- }
+// cardinality tracks how many tokens are expected for a given spec
+// - zero is a boolean, which does to expect any value
+// - one is an ordinary option that will be parsed from a single token
+// - multiple is a slice or map that can accept zero or more tokens
+type cardinality int
- // Look inside pointer types
- if t.Kind() == reflect.Ptr {
- t = t.Elem()
- }
- // Look inside slice types
- if t.Kind() == reflect.Slice {
- multiple = true
- t = t.Elem()
+const (
+ zero cardinality = iota
+ one
+ multiple
+ unsupported
+)
+
+func (k cardinality) String() string {
+ switch k {
+ case zero:
+ return "zero"
+ case one:
+ return "one"
+ case multiple:
+ return "multiple"
+ case unsupported:
+ return "unsupported"
+ default:
+ return fmt.Sprintf("unknown(%d)", int(k))
}
+}
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
+// cardinalityOf returns true if the type can be parsed from a string
+func cardinalityOf(t reflect.Type) (cardinality, error) {
+ if scalar.CanParse(t) {
+ if isBoolean(t) {
+ return zero, nil
+ } else {
+ return one, nil
+ }
}
- // Look inside pointer types (again, in case of []*Type)
+ // look inside pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
+ // look inside slice and map types
+ switch t.Kind() {
+ case reflect.Slice:
+ if !scalar.CanParse(t.Elem()) {
+ return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem())
+ }
+ return multiple, nil
+ case reflect.Map:
+ if !scalar.CanParse(t.Key()) {
+ return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem())
+ }
+ if !scalar.CanParse(t.Elem()) {
+ return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem())
+ }
+ return multiple, nil
+ default:
+ return unsupported, fmt.Errorf("cannot parse into %v", t)
}
-
- return false, false, false
}
// isBoolean returns true if the type can be parsed from a single string
diff --git a/reflect_test.go b/reflect_test.go
index 07b459c..8d65fd9 100644
--- a/reflect_test.go
+++ b/reflect_test.go
@@ -7,36 +7,54 @@ import (
"github.com/stretchr/testify/assert"
)
-func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) {
- p, b, m := canParse(typ)
- assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p)
- assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b)
- assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m)
+func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) {
+ actual, err := cardinalityOf(typ)
+ assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual)
+ if expected == unsupported {
+ assert.Error(t, err)
+ }
}
-func TestCanParse(t *testing.T) {
+func TestCardinalityOf(t *testing.T) {
var b bool
var i int
var s string
var f float64
var bs []bool
var is []int
+ var m map[string]int
+ var unsupported1 struct{}
+ var unsupported2 []struct{}
+ var unsupported3 map[string]struct{}
+ var unsupported4 map[struct{}]string
- assertCanParse(t, reflect.TypeOf(b), true, true, false)
- assertCanParse(t, reflect.TypeOf(i), true, false, false)
- assertCanParse(t, reflect.TypeOf(s), true, false, false)
- assertCanParse(t, reflect.TypeOf(f), true, false, false)
+ assertCardinality(t, reflect.TypeOf(b), zero)
+ assertCardinality(t, reflect.TypeOf(i), one)
+ assertCardinality(t, reflect.TypeOf(s), one)
+ assertCardinality(t, reflect.TypeOf(f), one)
- assertCanParse(t, reflect.TypeOf(&b), true, true, false)
- assertCanParse(t, reflect.TypeOf(&s), true, false, false)
- assertCanParse(t, reflect.TypeOf(&i), true, false, false)
- assertCanParse(t, reflect.TypeOf(&f), true, false, false)
+ assertCardinality(t, reflect.TypeOf(&b), zero)
+ assertCardinality(t, reflect.TypeOf(&s), one)
+ assertCardinality(t, reflect.TypeOf(&i), one)
+ assertCardinality(t, reflect.TypeOf(&f), one)
- assertCanParse(t, reflect.TypeOf(bs), true, true, true)
- assertCanParse(t, reflect.TypeOf(&bs), true, true, true)
+ assertCardinality(t, reflect.TypeOf(bs), multiple)
+ assertCardinality(t, reflect.TypeOf(is), multiple)
- assertCanParse(t, reflect.TypeOf(is), true, false, true)
- assertCanParse(t, reflect.TypeOf(&is), true, false, true)
+ assertCardinality(t, reflect.TypeOf(&bs), multiple)
+ assertCardinality(t, reflect.TypeOf(&is), multiple)
+
+ assertCardinality(t, reflect.TypeOf(m), multiple)
+ assertCardinality(t, reflect.TypeOf(&m), multiple)
+
+ assertCardinality(t, reflect.TypeOf(unsupported1), unsupported)
+ assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported)
+ assertCardinality(t, reflect.TypeOf(unsupported2), unsupported)
+ assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported)
+ assertCardinality(t, reflect.TypeOf(unsupported3), unsupported)
+ assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported)
+ assertCardinality(t, reflect.TypeOf(unsupported4), unsupported)
+ assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported)
}
type implementsTextUnmarshaler struct{}
@@ -45,13 +63,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
return nil
}
-func TestCanParseTextUnmarshaler(t *testing.T) {
- var u implementsTextUnmarshaler
- var su []implementsTextUnmarshaler
- assertCanParse(t, reflect.TypeOf(u), true, false, false)
- assertCanParse(t, reflect.TypeOf(&u), true, false, false)
- assertCanParse(t, reflect.TypeOf(su), true, false, true)
- assertCanParse(t, reflect.TypeOf(&su), true, false, true)
+func TestCardinalityTextUnmarshaler(t *testing.T) {
+ var x implementsTextUnmarshaler
+ var s []implementsTextUnmarshaler
+ var m []implementsTextUnmarshaler
+ assertCardinality(t, reflect.TypeOf(x), one)
+ assertCardinality(t, reflect.TypeOf(&x), one)
+ assertCardinality(t, reflect.TypeOf(s), multiple)
+ assertCardinality(t, reflect.TypeOf(&s), multiple)
+ assertCardinality(t, reflect.TypeOf(m), multiple)
+ assertCardinality(t, reflect.TypeOf(&m), multiple)
}
func TestIsExported(t *testing.T) {
@@ -60,3 +81,11 @@ func TestIsExported(t *testing.T) {
assert.False(t, isExported(""))
assert.False(t, isExported(string([]byte{255})))
}
+
+func TestCardinalityString(t *testing.T) {
+ assert.Equal(t, "zero", zero.String())
+ assert.Equal(t, "one", one.String())
+ assert.Equal(t, "multiple", multiple.String())
+ assert.Equal(t, "unsupported", unsupported.String())
+ assert.Equal(t, "unknown(42)", cardinality(42).String())
+}
diff --git a/sequence.go b/sequence.go
new file mode 100644
index 0000000..35a3614
--- /dev/null
+++ b/sequence.go
@@ -0,0 +1,123 @@
+package arg
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+
+ scalar "github.com/alexflint/go-scalar"
+)
+
+// setSliceOrMap parses a sequence of strings into a slice or map. If clear is
+// true then any values already in the slice or map are first removed.
+func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
+ if !dest.CanSet() {
+ return fmt.Errorf("field is not writable")
+ }
+
+ t := dest.Type()
+ if t.Kind() == reflect.Ptr {
+ dest = dest.Elem()
+ t = t.Elem()
+ }
+
+ switch t.Kind() {
+ case reflect.Slice:
+ return setSlice(dest, values, clear)
+ case reflect.Map:
+ return setMap(dest, values, clear)
+ default:
+ return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t)
+ }
+}
+
+// setSlice parses a sequence of strings and inserts them into a slice. If clear
+// is true then any values already in the slice are removed.
+func setSlice(dest reflect.Value, values []string, clear bool) error {
+ var ptr bool
+ elem := dest.Type().Elem()
+ if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
+ ptr = true
+ elem = elem.Elem()
+ }
+
+ // clear the slice in case default values exist
+ if clear && !dest.IsNil() {
+ dest.SetLen(0)
+ }
+
+ // parse the values one-by-one
+ for _, s := range values {
+ v := reflect.New(elem)
+ if err := scalar.ParseValue(v.Elem(), s); err != nil {
+ return err
+ }
+ if !ptr {
+ v = v.Elem()
+ }
+ dest.Set(reflect.Append(dest, v))
+ }
+ return nil
+}
+
+// setMap parses a sequence of name=value strings and inserts them into a map.
+// If clear is true then any values already in the map are removed.
+func setMap(dest reflect.Value, values []string, clear bool) error {
+ // determine the key and value type
+ var keyIsPtr bool
+ keyType := dest.Type().Key()
+ if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) {
+ keyIsPtr = true
+ keyType = keyType.Elem()
+ }
+
+ var valIsPtr bool
+ valType := dest.Type().Elem()
+ if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) {
+ valIsPtr = true
+ valType = valType.Elem()
+ }
+
+ // clear the slice in case default values exist
+ if clear && !dest.IsNil() {
+ for _, k := range dest.MapKeys() {
+ dest.SetMapIndex(k, reflect.Value{})
+ }
+ }
+
+ // allocate the map if it is not allocated
+ if dest.IsNil() {
+ dest.Set(reflect.MakeMap(dest.Type()))
+ }
+
+ // parse the values one-by-one
+ for _, s := range values {
+ // split at the first equals sign
+ pos := strings.Index(s, "=")
+ if pos == -1 {
+ return fmt.Errorf("cannot parse %q into a map, expected format key=value", s)
+ }
+
+ // parse the key
+ k := reflect.New(keyType)
+ if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil {
+ return err
+ }
+ if !keyIsPtr {
+ k = k.Elem()
+ }
+
+ // parse the value
+ v := reflect.New(valType)
+ if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil {
+ return err
+ }
+ if !valIsPtr {
+ v = v.Elem()
+ }
+
+ // add it to the map
+ dest.SetMapIndex(k, v)
+ }
+ return nil
+}
diff --git a/sequence_test.go b/sequence_test.go
new file mode 100644
index 0000000..fde3e3a
--- /dev/null
+++ b/sequence_test.go
@@ -0,0 +1,152 @@
+package arg
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSetSliceWithoutClearing(t *testing.T) {
+ xs := []int{10}
+ entries := []string{"1", "2", "3"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false)
+ require.NoError(t, err)
+ assert.Equal(t, []int{10, 1, 2, 3}, xs)
+}
+
+func TestSetSliceAfterClearing(t *testing.T) {
+ xs := []int{100}
+ entries := []string{"1", "2", "3"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ require.NoError(t, err)
+ assert.Equal(t, []int{1, 2, 3}, xs)
+}
+
+func TestSetSliceInvalid(t *testing.T) {
+ xs := []int{100}
+ entries := []string{"invalid"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetSlicePtr(t *testing.T) {
+ var xs []*int
+ entries := []string{"1", "2", "3"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, xs, 3)
+ assert.Equal(t, 1, *xs[0])
+ assert.Equal(t, 2, *xs[1])
+ assert.Equal(t, 3, *xs[2])
+}
+
+func TestSetSliceTextUnmarshaller(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var xs []*textUnmarshaler
+ entries := []string{"a", "aa", "aaa"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, xs, 3)
+ assert.Equal(t, 1, xs[0].val)
+ assert.Equal(t, 2, xs[1].val)
+ assert.Equal(t, 3, xs[2].val)
+}
+
+func TestSetMapWithoutClearing(t *testing.T) {
+ m := map[string]int{"foo": 10}
+ entries := []string{"a=1", "b=2"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, false)
+ require.NoError(t, err)
+ require.Len(t, m, 3)
+ assert.Equal(t, 1, m["a"])
+ assert.Equal(t, 2, m["b"])
+ assert.Equal(t, 10, m["foo"])
+}
+
+func TestSetMapAfterClearing(t *testing.T) {
+ m := map[string]int{"foo": 10}
+ entries := []string{"a=1", "b=2"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 2)
+ assert.Equal(t, 1, m["a"])
+ assert.Equal(t, 2, m["b"])
+}
+
+func TestSetMapWithKeyPointer(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[*string]int
+ entries := []string{"abc=123"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 1)
+}
+
+func TestSetMapWithValuePointer(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[string]*int
+ entries := []string{"abc=123"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 1)
+ assert.Equal(t, 123, *m["abc"])
+}
+
+func TestSetMapTextUnmarshaller(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[textUnmarshaler]*textUnmarshaler
+ entries := []string{"a=123", "aa=12", "aaa=1"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 3)
+ assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}])
+ assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}])
+ assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}])
+}
+
+func TestSetMapInvalidKey(t *testing.T) {
+ var m map[int]int
+ entries := []string{"invalid=123"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetMapInvalidValue(t *testing.T) {
+ var m map[int]int
+ entries := []string{"123=invalid"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetMapMalformed(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[string]string
+ entries := []string{"missing_equals_sign"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetSliceOrMapErrors(t *testing.T) {
+ var err error
+ var dest reflect.Value
+
+ // converting a slice to a reflect.Value in this way will make it read only
+ var cannotSet []int
+ dest = reflect.ValueOf(cannotSet)
+ err = setSliceOrMap(dest, nil, false)
+ assert.Error(t, err)
+
+ // check what happens when we pass in something that is not a slice or a map
+ var notSliceOrMap string
+ dest = reflect.ValueOf(&notSliceOrMap).Elem()
+ err = setSliceOrMap(dest, nil, false)
+ assert.Error(t, err)
+
+ // check what happens when we pass in a pointer to something that is not a slice or a map
+ var stringPtr *string
+ dest = reflect.ValueOf(&stringPtr).Elem()
+ err = setSliceOrMap(dest, nil, false)
+ assert.Error(t, err)
+}
diff --git a/usage.go b/usage.go
index cbbb021..231476b 100644
--- a/usage.go
+++ b/usage.go
@@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
for _, spec := range positionals {
// prefix with a space
fmt.Fprint(w, " ")
- if spec.multiple {
+ if spec.cardinality == multiple {
if !spec.required {
fmt.Fprint(w, "[")
}
@@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
// write the list of built in options
p.printOption(w, &spec{
- boolean: true,
- long: "help",
- short: "h",
- help: "display this help and exit",
+ cardinality: zero,
+ long: "help",
+ short: "h",
+ help: "display this help and exit",
})
if p.version != "" {
p.printOption(w, &spec{
- boolean: true,
- long: "version",
- help: "display version and exit",
+ cardinality: zero,
+ long: "version",
+ help: "display version and exit",
})
}
@@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
}
func synopsis(spec *spec, form string) string {
- if spec.boolean {
+ if spec.cardinality == zero {
return form
}
return form + " " + spec.placeholder