summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--parse.go4
-rw-r--r--parse_test.go10
-rw-r--r--reflect.go82
-rw-r--r--reflect_test.go66
-rw-r--r--sequence_test.go8
5 files changed, 118 insertions, 52 deletions
diff --git a/parse.go b/parse.go
index 84a7ed1..37df734 100644
--- a/parse.go
+++ b/parse.go
@@ -377,7 +377,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
cmd.specs = append(cmd.specs, &spec)
var parseable bool
- parseable, spec.boolean, spec.multiple = canParse(field.Type)
+ //parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
@@ -728,7 +728,7 @@ func findSubcommand(cmds []*command, name string) *command {
// isZero returns true if v contains the zero value for its type
func isZero(v reflect.Value) bool {
t := v.Type()
- if t.Kind() == reflect.Slice {
+ if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
return v.IsNil()
}
if !t.Comparable() {
diff --git a/parse_test.go b/parse_test.go
index ce3068e..0decfc1 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -220,6 +220,16 @@ func TestLongFlag(t *testing.T) {
assert.Equal(t, "xyz", args.Foo)
}
+func TestSliceOfBools(t *testing.T) {
+ var args struct {
+ B []bool
+ }
+
+ err := parse("--b true false true", &args)
+ require.NoError(t, err)
+ assert.Equal(t, []bool{true, false, true}, args.B)
+}
+
func TestPlaceholder(t *testing.T) {
var args struct {
Input string `arg:"positional" placeholder:"SRC"`
diff --git a/reflect.go b/reflect.go
index f1e8e8d..c4fc5d9 100644
--- a/reflect.go
+++ b/reflect.go
@@ -2,6 +2,7 @@ package arg
import (
"encoding"
+ "fmt"
"reflect"
"unicode"
"unicode/utf8"
@@ -11,42 +12,71 @@ import (
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
-// canParse returns true if the type can be parsed from a string
-func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
- }
+// kind is used to track the various kinds of options:
+// - regular is an ordinary option that will be parsed from a single token
+// - binary is an option that will be true if present but does not expect an explicit value
+// - sequence is an option that accepts multiple values and will end up in a slice
+// - mapping is an option that acccepts multiple key=value strings and will end up in a map
+type kind int
- // Look inside pointer types
- if t.Kind() == reflect.Ptr {
- t = t.Elem()
- }
- // Look inside slice types
- if t.Kind() == reflect.Slice {
- multiple = true
- t = t.Elem()
+const (
+ regular kind = iota
+ binary
+ sequence
+ mapping
+ unsupported
+)
+
+func (k kind) String() string {
+ switch k {
+ case regular:
+ return "regular"
+ case binary:
+ return "binary"
+ case sequence:
+ return "sequence"
+ case mapping:
+ return "mapping"
+ case unsupported:
+ return "unsupported"
+ default:
+ return fmt.Sprintf("unknown(%d)", int(k))
}
+}
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
+// kindOf returns true if the type can be parsed from a string
+func kindOf(t reflect.Type) (kind, error) {
+ if scalar.CanParse(t) {
+ if isBoolean(t) {
+ return binary, nil
+ } else {
+ return regular, nil
+ }
}
- // Look inside pointer types (again, in case of []*Type)
+ // look inside pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
+ // look inside slice and map types
+ switch t.Kind() {
+ case reflect.Slice:
+ if !scalar.CanParse(t.Elem()) {
+ return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into %v", t, t.Elem())
+ }
+ return sequence, nil
+ case reflect.Map:
+ if !scalar.CanParse(t.Key()) {
+ return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the key type %v", t, t.Elem())
+ }
+ if !scalar.CanParse(t.Elem()) {
+ return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the value type %v", t, t.Elem())
+ }
+ return mapping, nil
+ default:
+ return unsupported, fmt.Errorf("cannot parse into %v", t)
}
-
- return false, false, false
}
// isBoolean returns true if the type can be parsed from a single string
diff --git a/reflect_test.go b/reflect_test.go
index 07b459c..6a8af49 100644
--- a/reflect_test.go
+++ b/reflect_test.go
@@ -7,36 +7,51 @@ import (
"github.com/stretchr/testify/assert"
)
-func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) {
- p, b, m := canParse(typ)
- assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p)
- assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b)
- assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m)
+func assertKind(t *testing.T, typ reflect.Type, expected kind) {
+ actual, err := kindOf(typ)
+ assert.Equal(t, expected, actual, "expected %v to have kind %v but got %v", typ, expected, actual)
+ if expected == unsupported {
+ assert.Error(t, err)
+ }
}
-func TestCanParse(t *testing.T) {
+func TestKindOf(t *testing.T) {
var b bool
var i int
var s string
var f float64
var bs []bool
var is []int
+ var m map[string]int
+ var unsupported1 struct{}
+ var unsupported2 []struct{}
+ var unsupported3 map[string]struct{}
- assertCanParse(t, reflect.TypeOf(b), true, true, false)
- assertCanParse(t, reflect.TypeOf(i), true, false, false)
- assertCanParse(t, reflect.TypeOf(s), true, false, false)
- assertCanParse(t, reflect.TypeOf(f), true, false, false)
+ assertKind(t, reflect.TypeOf(b), binary)
+ assertKind(t, reflect.TypeOf(i), regular)
+ assertKind(t, reflect.TypeOf(s), regular)
+ assertKind(t, reflect.TypeOf(f), regular)
- assertCanParse(t, reflect.TypeOf(&b), true, true, false)
- assertCanParse(t, reflect.TypeOf(&s), true, false, false)
- assertCanParse(t, reflect.TypeOf(&i), true, false, false)
- assertCanParse(t, reflect.TypeOf(&f), true, false, false)
+ assertKind(t, reflect.TypeOf(&b), binary)
+ assertKind(t, reflect.TypeOf(&s), regular)
+ assertKind(t, reflect.TypeOf(&i), regular)
+ assertKind(t, reflect.TypeOf(&f), regular)
- assertCanParse(t, reflect.TypeOf(bs), true, true, true)
- assertCanParse(t, reflect.TypeOf(&bs), true, true, true)
+ assertKind(t, reflect.TypeOf(bs), sequence)
+ assertKind(t, reflect.TypeOf(is), sequence)
- assertCanParse(t, reflect.TypeOf(is), true, false, true)
- assertCanParse(t, reflect.TypeOf(&is), true, false, true)
+ assertKind(t, reflect.TypeOf(&bs), sequence)
+ assertKind(t, reflect.TypeOf(&is), sequence)
+
+ assertKind(t, reflect.TypeOf(m), mapping)
+ assertKind(t, reflect.TypeOf(&m), mapping)
+
+ assertKind(t, reflect.TypeOf(unsupported1), unsupported)
+ assertKind(t, reflect.TypeOf(&unsupported1), unsupported)
+ assertKind(t, reflect.TypeOf(unsupported2), unsupported)
+ assertKind(t, reflect.TypeOf(&unsupported2), unsupported)
+ assertKind(t, reflect.TypeOf(unsupported3), unsupported)
+ assertKind(t, reflect.TypeOf(&unsupported3), unsupported)
}
type implementsTextUnmarshaler struct{}
@@ -46,12 +61,15 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
}
func TestCanParseTextUnmarshaler(t *testing.T) {
- var u implementsTextUnmarshaler
- var su []implementsTextUnmarshaler
- assertCanParse(t, reflect.TypeOf(u), true, false, false)
- assertCanParse(t, reflect.TypeOf(&u), true, false, false)
- assertCanParse(t, reflect.TypeOf(su), true, false, true)
- assertCanParse(t, reflect.TypeOf(&su), true, false, true)
+ var x implementsTextUnmarshaler
+ var s []implementsTextUnmarshaler
+ var m []implementsTextUnmarshaler
+ assertKind(t, reflect.TypeOf(x), regular)
+ assertKind(t, reflect.TypeOf(&x), regular)
+ assertKind(t, reflect.TypeOf(s), sequence)
+ assertKind(t, reflect.TypeOf(&s), sequence)
+ assertKind(t, reflect.TypeOf(m), mapping)
+ assertKind(t, reflect.TypeOf(&m), mapping)
}
func TestIsExported(t *testing.T) {
diff --git a/sequence_test.go b/sequence_test.go
index 4646811..446bc42 100644
--- a/sequence_test.go
+++ b/sequence_test.go
@@ -79,3 +79,11 @@ func TestSetMapTextUnmarshaller(t *testing.T) {
assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}])
assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}])
}
+
+func TestSetMapMalformed(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[string]string
+ entries := []string{"missing_equals_sign"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}