From 1dfefdc43e8a9a06b532b5c29f876eb38f86a928 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Mon, 19 Apr 2021 12:10:53 -0700 Subject: factor setSlice into its own file, add setMap, and add tests for both --- sequence_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 sequence_test.go (limited to 'sequence_test.go') diff --git a/sequence_test.go b/sequence_test.go new file mode 100644 index 0000000..4646811 --- /dev/null +++ b/sequence_test.go @@ -0,0 +1,81 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetSliceWithoutClearing(t *testing.T) { + xs := []int{10} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + require.NoError(t, err) + assert.Equal(t, []int{10, 1, 2, 3}, xs) +} + +func TestSetSliceWithClear(t *testing.T) { + xs := []int{100} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, xs) +} + +func TestSetSlicePtr(t *testing.T) { + var xs []*int + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, *xs[0]) + assert.Equal(t, 2, *xs[1]) + assert.Equal(t, 3, *xs[2]) +} + +func TestSetSliceTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var xs []*textUnmarshaler + entries := []string{"a", "aa", "aaa"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, xs[0].val) + assert.Equal(t, 2, xs[1].val) + assert.Equal(t, 3, xs[2].val) +} + +func TestSetMapWithoutClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) + assert.Equal(t, 10, m["foo"]) +} + +func TestSetMapWithClear(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 2) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) +} + +func TestSetMapTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[textUnmarshaler]*textUnmarshaler + entries := []string{"a=123", "aa=12", "aaa=1"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) + assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) + assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) +} -- cgit v1.2.3 From 23b96d7aacf62828675decc309eae5b9dce5bd51 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Mon, 19 Apr 2021 12:49:49 -0700 Subject: refactor canParse into kindOf --- parse.go | 4 +-- parse_test.go | 10 +++++++ reflect.go | 82 ++++++++++++++++++++++++++++++++++++++------------------ reflect_test.go | 66 ++++++++++++++++++++++++++++----------------- sequence_test.go | 8 ++++++ 5 files changed, 118 insertions(+), 52 deletions(-) (limited to 'sequence_test.go') diff --git a/parse.go b/parse.go index 84a7ed1..37df734 100644 --- a/parse.go +++ b/parse.go @@ -377,7 +377,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { cmd.specs = append(cmd.specs, &spec) var parseable bool - parseable, spec.boolean, spec.multiple = canParse(field.Type) + //parseable, spec.boolean, spec.multiple = canParse(field.Type) if !parseable { errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())) @@ -728,7 +728,7 @@ func findSubcommand(cmds []*command, name string) *command { // isZero returns true if v contains the zero value for its type func isZero(v reflect.Value) bool { t := v.Type() - if t.Kind() == reflect.Slice { + if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { return v.IsNil() } if !t.Comparable() { diff --git a/parse_test.go b/parse_test.go index ce3068e..0decfc1 100644 --- a/parse_test.go +++ b/parse_test.go @@ -220,6 +220,16 @@ func TestLongFlag(t *testing.T) { assert.Equal(t, "xyz", args.Foo) } +func TestSliceOfBools(t *testing.T) { + var args struct { + B []bool + } + + err := parse("--b true false true", &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, args.B) +} + func TestPlaceholder(t *testing.T) { var args struct { Input string `arg:"positional" placeholder:"SRC"` diff --git a/reflect.go b/reflect.go index f1e8e8d..c4fc5d9 100644 --- a/reflect.go +++ b/reflect.go @@ -2,6 +2,7 @@ package arg import ( "encoding" + "fmt" "reflect" "unicode" "unicode/utf8" @@ -11,42 +12,71 @@ import ( var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() -// canParse returns true if the type can be parsed from a string -func canParse(t reflect.Type) (parseable, boolean, multiple bool) { - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return - } +// kind is used to track the various kinds of options: +// - regular is an ordinary option that will be parsed from a single token +// - binary is an option that will be true if present but does not expect an explicit value +// - sequence is an option that accepts multiple values and will end up in a slice +// - mapping is an option that acccepts multiple key=value strings and will end up in a map +type kind int - // Look inside pointer types - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - // Look inside slice types - if t.Kind() == reflect.Slice { - multiple = true - t = t.Elem() +const ( + regular kind = iota + binary + sequence + mapping + unsupported +) + +func (k kind) String() string { + switch k { + case regular: + return "regular" + case binary: + return "binary" + case sequence: + return "sequence" + case mapping: + return "mapping" + case unsupported: + return "unsupported" + default: + return fmt.Sprintf("unknown(%d)", int(k)) } +} - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return +// kindOf returns true if the type can be parsed from a string +func kindOf(t reflect.Type) (kind, error) { + if scalar.CanParse(t) { + if isBoolean(t) { + return binary, nil + } else { + return regular, nil + } } - // Look inside pointer types (again, in case of []*Type) + // look inside pointer types if t.Kind() == reflect.Ptr { t = t.Elem() } - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return + // look inside slice and map types + switch t.Kind() { + case reflect.Slice: + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into %v", t, t.Elem()) + } + return sequence, nil + case reflect.Map: + if !scalar.CanParse(t.Key()) { + return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the key type %v", t, t.Elem()) + } + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the value type %v", t, t.Elem()) + } + return mapping, nil + default: + return unsupported, fmt.Errorf("cannot parse into %v", t) } - - return false, false, false } // isBoolean returns true if the type can be parsed from a single string diff --git a/reflect_test.go b/reflect_test.go index 07b459c..6a8af49 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -7,36 +7,51 @@ import ( "github.com/stretchr/testify/assert" ) -func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) { - p, b, m := canParse(typ) - assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p) - assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b) - assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m) +func assertKind(t *testing.T, typ reflect.Type, expected kind) { + actual, err := kindOf(typ) + assert.Equal(t, expected, actual, "expected %v to have kind %v but got %v", typ, expected, actual) + if expected == unsupported { + assert.Error(t, err) + } } -func TestCanParse(t *testing.T) { +func TestKindOf(t *testing.T) { var b bool var i int var s string var f float64 var bs []bool var is []int + var m map[string]int + var unsupported1 struct{} + var unsupported2 []struct{} + var unsupported3 map[string]struct{} - assertCanParse(t, reflect.TypeOf(b), true, true, false) - assertCanParse(t, reflect.TypeOf(i), true, false, false) - assertCanParse(t, reflect.TypeOf(s), true, false, false) - assertCanParse(t, reflect.TypeOf(f), true, false, false) + assertKind(t, reflect.TypeOf(b), binary) + assertKind(t, reflect.TypeOf(i), regular) + assertKind(t, reflect.TypeOf(s), regular) + assertKind(t, reflect.TypeOf(f), regular) - assertCanParse(t, reflect.TypeOf(&b), true, true, false) - assertCanParse(t, reflect.TypeOf(&s), true, false, false) - assertCanParse(t, reflect.TypeOf(&i), true, false, false) - assertCanParse(t, reflect.TypeOf(&f), true, false, false) + assertKind(t, reflect.TypeOf(&b), binary) + assertKind(t, reflect.TypeOf(&s), regular) + assertKind(t, reflect.TypeOf(&i), regular) + assertKind(t, reflect.TypeOf(&f), regular) - assertCanParse(t, reflect.TypeOf(bs), true, true, true) - assertCanParse(t, reflect.TypeOf(&bs), true, true, true) + assertKind(t, reflect.TypeOf(bs), sequence) + assertKind(t, reflect.TypeOf(is), sequence) - assertCanParse(t, reflect.TypeOf(is), true, false, true) - assertCanParse(t, reflect.TypeOf(&is), true, false, true) + assertKind(t, reflect.TypeOf(&bs), sequence) + assertKind(t, reflect.TypeOf(&is), sequence) + + assertKind(t, reflect.TypeOf(m), mapping) + assertKind(t, reflect.TypeOf(&m), mapping) + + assertKind(t, reflect.TypeOf(unsupported1), unsupported) + assertKind(t, reflect.TypeOf(&unsupported1), unsupported) + assertKind(t, reflect.TypeOf(unsupported2), unsupported) + assertKind(t, reflect.TypeOf(&unsupported2), unsupported) + assertKind(t, reflect.TypeOf(unsupported3), unsupported) + assertKind(t, reflect.TypeOf(&unsupported3), unsupported) } type implementsTextUnmarshaler struct{} @@ -46,12 +61,15 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error { } func TestCanParseTextUnmarshaler(t *testing.T) { - var u implementsTextUnmarshaler - var su []implementsTextUnmarshaler - assertCanParse(t, reflect.TypeOf(u), true, false, false) - assertCanParse(t, reflect.TypeOf(&u), true, false, false) - assertCanParse(t, reflect.TypeOf(su), true, false, true) - assertCanParse(t, reflect.TypeOf(&su), true, false, true) + var x implementsTextUnmarshaler + var s []implementsTextUnmarshaler + var m []implementsTextUnmarshaler + assertKind(t, reflect.TypeOf(x), regular) + assertKind(t, reflect.TypeOf(&x), regular) + assertKind(t, reflect.TypeOf(s), sequence) + assertKind(t, reflect.TypeOf(&s), sequence) + assertKind(t, reflect.TypeOf(m), mapping) + assertKind(t, reflect.TypeOf(&m), mapping) } func TestIsExported(t *testing.T) { diff --git a/sequence_test.go b/sequence_test.go index 4646811..446bc42 100644 --- a/sequence_test.go +++ b/sequence_test.go @@ -79,3 +79,11 @@ func TestSetMapTextUnmarshaller(t *testing.T) { assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) } + +func TestSetMapMalformed(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]string + entries := []string{"missing_equals_sign"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} -- cgit v1.2.3 From 0100c0a411486a0a26c0d7bb5504c5f371aaf6b0 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Mon, 19 Apr 2021 13:48:07 -0700 Subject: push code coverage up --- sequence_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) (limited to 'sequence_test.go') diff --git a/sequence_test.go b/sequence_test.go index 446bc42..dd866c6 100644 --- a/sequence_test.go +++ b/sequence_test.go @@ -16,7 +16,7 @@ func TestSetSliceWithoutClearing(t *testing.T) { assert.Equal(t, []int{10, 1, 2, 3}, xs) } -func TestSetSliceWithClear(t *testing.T) { +func TestSetSliceAfterClearing(t *testing.T) { xs := []int{100} entries := []string{"1", "2", "3"} err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) @@ -24,6 +24,13 @@ func TestSetSliceWithClear(t *testing.T) { assert.Equal(t, []int{1, 2, 3}, xs) } +func TestSetSliceInvalid(t *testing.T) { + xs := []int{100} + entries := []string{"invalid"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + assert.Error(t, err) +} + func TestSetSlicePtr(t *testing.T) { var xs []*int entries := []string{"1", "2", "3"} @@ -58,7 +65,7 @@ func TestSetMapWithoutClearing(t *testing.T) { assert.Equal(t, 10, m["foo"]) } -func TestSetMapWithClear(t *testing.T) { +func TestSetMapAfterClearing(t *testing.T) { m := map[string]int{"foo": 10} entries := []string{"a=1", "b=2"} err := setMap(reflect.ValueOf(&m).Elem(), entries, true) @@ -68,6 +75,25 @@ func TestSetMapWithClear(t *testing.T) { assert.Equal(t, 2, m["b"]) } +func TestSetMapWithKeyPointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[*string]int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) +} + +func TestSetMapWithValuePointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]*int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) + assert.Equal(t, 123, *m["abc"]) +} + func TestSetMapTextUnmarshaller(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[textUnmarshaler]*textUnmarshaler @@ -80,6 +106,20 @@ func TestSetMapTextUnmarshaller(t *testing.T) { assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) } +func TestSetMapInvalidKey(t *testing.T) { + var m map[int]int + entries := []string{"invalid=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapInvalidValue(t *testing.T) { + var m map[int]int + entries := []string{"123=invalid"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + func TestSetMapMalformed(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[string]string -- cgit v1.2.3 From d4b9b2a00813ef6f28f75a685bd868aab4609ec4 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Mon, 19 Apr 2021 19:23:08 -0700 Subject: push coverage up even more --- sequence_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'sequence_test.go') diff --git a/sequence_test.go b/sequence_test.go index dd866c6..fde3e3a 100644 --- a/sequence_test.go +++ b/sequence_test.go @@ -127,3 +127,26 @@ func TestSetMapMalformed(t *testing.T) { err := setMap(reflect.ValueOf(&m).Elem(), entries, true) assert.Error(t, err) } + +func TestSetSliceOrMapErrors(t *testing.T) { + var err error + var dest reflect.Value + + // converting a slice to a reflect.Value in this way will make it read only + var cannotSet []int + dest = reflect.ValueOf(cannotSet) + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in something that is not a slice or a map + var notSliceOrMap string + dest = reflect.ValueOf(¬SliceOrMap).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in a pointer to something that is not a slice or a map + var stringPtr *string + dest = reflect.ValueOf(&stringPtr).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) +} -- cgit v1.2.3