summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2018-04-20 07:57:08 -0700
committerGitHub <[email protected]>2018-04-20 07:57:08 -0700
commit074ee5f759999d103724b5594e33901adeb28e73 (patch)
tree150c205cd0c5b285c4ee04482a24f491e1e38342
parentb1eda2c7b64c118d56472c1944054bf9235d6c48 (diff)
parent4d71204936cbe2b4d7ebeb5b8a4d25432599eb17 (diff)
Merge pull request #64 from alexflint/repeated-unmarshaltext
Fix repeated arguments implementing TextUnmarshaler
-rw-r--r--parse.go39
-rw-r--r--parse_test.go26
2 files changed, 44 insertions, 21 deletions
diff --git a/parse.go b/parse.go
index 10c6841..1416223 100644
--- a/parse.go
+++ b/parse.go
@@ -159,7 +159,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
}
// Check whether this field is supported. It's good to do this here rather than
- // wait until setScalar because it means that a program with invalid argument
+ // wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
var parseable bool
@@ -275,7 +275,7 @@ func process(specs []*spec, args []string) error {
}
if spec.env != "" {
if value, found := os.LookupEnv(spec.env); found {
- err := setScalar(spec.dest, value)
+ err := scalar.ParseValue(spec.dest, value)
if err != nil {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
}
@@ -355,7 +355,7 @@ func process(specs []*spec, args []string) error {
i++
}
- err := setScalar(spec.dest, value)
+ err := scalar.ParseValue(spec.dest, value)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@@ -374,7 +374,7 @@ func process(specs []*spec, args []string) error {
}
positionals = nil
} else if len(positionals) > 0 {
- err := setScalar(spec.dest, positionals[0])
+ err := scalar.ParseValue(spec.dest, positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
@@ -426,7 +426,7 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error {
var ptr bool
elem := dest.Type().Elem()
- if elem.Kind() == reflect.Ptr {
+ if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
ptr = true
elem = elem.Elem()
}
@@ -438,7 +438,7 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error {
for _, s := range values {
v := reflect.New(elem)
- if err := setScalar(v.Elem(), s); err != nil {
+ if err := scalar.ParseValue(v.Elem(), s); err != nil {
return err
}
if !ptr {
@@ -451,7 +451,8 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error {
// canParse returns true if the type can be parsed from a string
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
- parseable, boolean = isScalar(t)
+ parseable = scalar.CanParse(t)
+ boolean = isBoolean(t)
if parseable {
return
}
@@ -466,7 +467,8 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
t = t.Elem()
}
- parseable, boolean = isScalar(t)
+ parseable = scalar.CanParse(t)
+ boolean = isBoolean(t)
if parseable {
return
}
@@ -476,7 +478,8 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
t = t.Elem()
}
- parseable, boolean = isScalar(t)
+ parseable = scalar.CanParse(t)
+ boolean = isBoolean(t)
if parseable {
return
}
@@ -486,22 +489,16 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
-// isScalar returns true if the type can be parsed from a single string
-func isScalar(t reflect.Type) (parseable, boolean bool) {
- parseable = scalar.CanParse(t)
+// isBoolean returns true if the type can be parsed from a single string
+func isBoolean(t reflect.Type) bool {
switch {
case t.Implements(textUnmarshalerType):
- return parseable, false
+ return false
case t.Kind() == reflect.Bool:
- return parseable, true
+ return true
case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
- return parseable, true
+ return true
default:
- return parseable, false
+ return false
}
}
-
-// set a value from a string
-func setScalar(v reflect.Value, s string) error {
- return scalar.ParseValue(v, s)
-}
diff --git a/parse_test.go b/parse_test.go
index 925a23e..1461c02 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -599,6 +599,32 @@ func TestTextUnmarshaler(t *testing.T) {
assert.Equal(t, 3, args.Foo.val)
}
+func TestRepeatedTextUnmarshaler(t *testing.T) {
+ // fields that implement TextUnmarshaler should be parsed using that interface
+ var args struct {
+ Foo []*textUnmarshaler
+ }
+ err := parse("--foo abc d ef", &args)
+ require.NoError(t, err)
+ require.Len(t, args.Foo, 3)
+ assert.Equal(t, 3, args.Foo[0].val)
+ assert.Equal(t, 1, args.Foo[1].val)
+ assert.Equal(t, 2, args.Foo[2].val)
+}
+
+func TestPositionalTextUnmarshaler(t *testing.T) {
+ // fields that implement TextUnmarshaler should be parsed using that interface
+ var args struct {
+ Foo []*textUnmarshaler `arg:"positional"`
+ }
+ err := parse("abc d ef", &args)
+ require.NoError(t, err)
+ require.Len(t, args.Foo, 3)
+ assert.Equal(t, 3, args.Foo[0].val)
+ assert.Equal(t, 1, args.Foo[1].val)
+ assert.Equal(t, 2, args.Foo[2].val)
+}
+
type boolUnmarshaler bool
func (p *boolUnmarshaler) UnmarshalText(b []byte) error {