summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--parse.go17
-rw-r--r--parse_test.go35
2 files changed, 45 insertions, 7 deletions
diff --git a/parse.go b/parse.go
index db15e5a..a82e377 100644
--- a/parse.go
+++ b/parse.go
@@ -1,6 +1,7 @@
package arg
import (
+ "encoding"
"errors"
"fmt"
"os"
@@ -445,9 +446,21 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
return false, false, false
}
+var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
+
// isScalar returns true if the type can be parsed from a single string
-func isScalar(t reflect.Type) (bool, bool) {
- return scalar.CanParse(t), t.Kind() == reflect.Bool
+func isScalar(t reflect.Type) (parseable, boolean bool) {
+ parseable = scalar.CanParse(t)
+ switch {
+ case t.Implements(textUnmarshalerType):
+ return parseable, false
+ case t.Kind() == reflect.Bool:
+ return parseable, true
+ case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
+ return parseable, true
+ default:
+ return parseable, false
+ }
}
// set a value from a string
diff --git a/parse_test.go b/parse_test.go
index 8779b6f..5e88700 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -33,46 +33,71 @@ func parse(cmdline string, dest interface{}) error {
func TestString(t *testing.T) {
var args struct {
Foo string
+ Ptr *string
}
- err := parse("--foo bar", &args)
+ err := parse("--foo bar --ptr baz", &args)
require.NoError(t, err)
assert.Equal(t, "bar", args.Foo)
+ assert.Equal(t, "baz", *args.Ptr)
+}
+
+func TestBool(t *testing.T) {
+ var args struct {
+ A bool
+ B bool
+ C *bool
+ D *bool
+ }
+ err := parse("--a --c", &args)
+ require.NoError(t, err)
+ assert.True(t, args.A)
+ assert.False(t, args.B)
+ assert.True(t, *args.C)
+ assert.Nil(t, args.D)
}
func TestInt(t *testing.T) {
var args struct {
Foo int
+ Ptr *int
}
- err := parse("--foo 7", &args)
+ err := parse("--foo 7 --ptr 8", &args)
require.NoError(t, err)
assert.EqualValues(t, 7, args.Foo)
+ assert.EqualValues(t, 8, *args.Ptr)
}
func TestUint(t *testing.T) {
var args struct {
Foo uint
+ Ptr *uint
}
- err := parse("--foo 7", &args)
+ err := parse("--foo 7 --ptr 8", &args)
require.NoError(t, err)
assert.EqualValues(t, 7, args.Foo)
+ assert.EqualValues(t, 8, *args.Ptr)
}
func TestFloat(t *testing.T) {
var args struct {
Foo float32
+ Ptr *float32
}
- err := parse("--foo 3.4", &args)
+ err := parse("--foo 3.4 --ptr 3.5", &args)
require.NoError(t, err)
assert.EqualValues(t, 3.4, args.Foo)
+ assert.EqualValues(t, 3.5, *args.Ptr)
}
func TestDuration(t *testing.T) {
var args struct {
Foo time.Duration
+ Ptr *time.Duration
}
- err := parse("--foo 3ms", &args)
+ err := parse("--foo 3ms --ptr 4ms", &args)
require.NoError(t, err)
assert.Equal(t, 3*time.Millisecond, args.Foo)
+ assert.Equal(t, 4*time.Millisecond, *args.Ptr)
}
func TestInvalidDuration(t *testing.T) {