summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavel Borzenkov <[email protected]>2018-11-16 17:09:52 +0300
committerPavel Borzenkov <[email protected]>2018-11-16 17:09:52 +0300
commit38f8eb7c6bf0819f9f2e4e980067994d372a7b76 (patch)
tree50a325336f572672fdc31603870086840119fd9d
parente80c3b7ed292b052c7083b6fd7154a8422c33f65 (diff)
Allow to use values (not pointers) with TextUnmarshaler
The patch makes sure that both values and pointer to values are checked for custom TextUnmarshal implementation. This will allow to use go-arg custom parsing as follows: var args struct { Arg CustomType } instead of var args struct { Arg *CustomType } Signed-off-by: Pavel Borzenkov <[email protected]>
-rw-r--r--scalar.go9
-rw-r--r--scalar_test.go12
2 files changed, 20 insertions, 1 deletions
diff --git a/scalar.go b/scalar.go
index 663f143..e4525b3 100644
--- a/scalar.go
+++ b/scalar.go
@@ -47,6 +47,13 @@ func ParseValue(v reflect.Value, s string) error {
if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s))
}
+ // If it's a value instead of a pointer, check that we can unmarshal it
+ // via TextUnmarshaler as well
+ if v.CanAddr() {
+ if scalar, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
+ return scalar.UnmarshalText([]byte(s))
+ }
+ }
// If we have a pointer then dereference it
if v.Kind() == reflect.Ptr {
@@ -126,7 +133,7 @@ func ParseValue(v reflect.Value, s string) error {
// CanParse returns true if the type can be parsed from a string.
func CanParse(t reflect.Type) bool {
// If it implements encoding.TextUnmarshaler then use that
- if t.Implements(textUnmarshalerType) {
+ if t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) {
return true
}
diff --git a/scalar_test.go b/scalar_test.go
index d70bd32..9a1ef6a 100644
--- a/scalar_test.go
+++ b/scalar_test.go
@@ -10,6 +10,15 @@ import (
"github.com/stretchr/testify/require"
)
+type textUnmarshaler struct {
+ val int
+}
+
+func (f *textUnmarshaler) UnmarshalText(b []byte) error {
+ f.val = len(b)
+ return nil
+}
+
func assertParse(t *testing.T, expected interface{}, str string) {
v := reflect.New(reflect.TypeOf(expected)).Elem()
err := ParseValue(v, str)
@@ -67,6 +76,9 @@ func TestParseValue(t *testing.T) {
// MAC addresses
assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab")
+
+ // custom text unmarshaler
+ assertParse(t, textUnmarshaler{3}, "abc")
}
func TestParse(t *testing.T) {