diff options
| author | Pavel Borzenkov <[email protected]> | 2018-11-16 17:09:52 +0300 |
|---|---|---|
| committer | Pavel Borzenkov <[email protected]> | 2018-11-16 17:09:52 +0300 |
| commit | 38f8eb7c6bf0819f9f2e4e980067994d372a7b76 (patch) | |
| tree | 50a325336f572672fdc31603870086840119fd9d | |
| parent | e80c3b7ed292b052c7083b6fd7154a8422c33f65 (diff) | |
Allow to use values (not pointers) with TextUnmarshaler
The patch makes sure that both values and pointer to values are checked
for custom TextUnmarshal implementation. This will allow to use go-arg
custom parsing as follows:
var args struct {
Arg CustomType
}
instead of
var args struct {
Arg *CustomType
}
Signed-off-by: Pavel Borzenkov <[email protected]>
| -rw-r--r-- | scalar.go | 9 | ||||
| -rw-r--r-- | scalar_test.go | 12 |
2 files changed, 20 insertions, 1 deletions
@@ -47,6 +47,13 @@ func ParseValue(v reflect.Value, s string) error { if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok { return scalar.UnmarshalText([]byte(s)) } + // If it's a value instead of a pointer, check that we can unmarshal it + // via TextUnmarshaler as well + if v.CanAddr() { + if scalar, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok { + return scalar.UnmarshalText([]byte(s)) + } + } // If we have a pointer then dereference it if v.Kind() == reflect.Ptr { @@ -126,7 +133,7 @@ func ParseValue(v reflect.Value, s string) error { // CanParse returns true if the type can be parsed from a string. func CanParse(t reflect.Type) bool { // If it implements encoding.TextUnmarshaler then use that - if t.Implements(textUnmarshalerType) { + if t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) { return true } diff --git a/scalar_test.go b/scalar_test.go index d70bd32..9a1ef6a 100644 --- a/scalar_test.go +++ b/scalar_test.go @@ -10,6 +10,15 @@ import ( "github.com/stretchr/testify/require" ) +type textUnmarshaler struct { + val int +} + +func (f *textUnmarshaler) UnmarshalText(b []byte) error { + f.val = len(b) + return nil +} + func assertParse(t *testing.T, expected interface{}, str string) { v := reflect.New(reflect.TypeOf(expected)).Elem() err := ParseValue(v, str) @@ -67,6 +76,9 @@ func TestParseValue(t *testing.T) { // MAC addresses assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab") + + // custom text unmarshaler + assertParse(t, textUnmarshaler{3}, "abc") } func TestParse(t *testing.T) { |
