summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2018-11-19 12:48:00 -0800
committerGitHub <[email protected]>2018-11-19 12:48:00 -0800
commit6ab8ad5e1c5b25ca2783fe83f493c3ab471407e2 (patch)
tree52e1f2fbbc29c8a42a882e67a6830b64cc4ead9d
parente80c3b7ed292b052c7083b6fd7154a8422c33f65 (diff)
parente1338aeff04713f53befe44b42323f55fd60338c (diff)
Merge pull request #2 from pborzenkov/text-unmarshaler-value
Allow to use values (not pointers) with TextUnmarshaler
-rw-r--r--scalar.go19
-rw-r--r--scalar_test.go12
2 files changed, 21 insertions, 10 deletions
diff --git a/scalar.go b/scalar.go
index 663f143..073392c 100644
--- a/scalar.go
+++ b/scalar.go
@@ -18,7 +18,6 @@ var (
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
durationType = reflect.TypeOf(time.Duration(0))
mailAddressType = reflect.TypeOf(mail.Address{})
- ipType = reflect.TypeOf(net.IP{})
macType = reflect.TypeOf(net.HardwareAddr{})
)
@@ -47,6 +46,13 @@ func ParseValue(v reflect.Value, s string) error {
if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s))
}
+ // If it's a value instead of a pointer, check that we can unmarshal it
+ // via TextUnmarshaler as well
+ if v.CanAddr() {
+ if scalar, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
+ return scalar.UnmarshalText([]byte(s))
+ }
+ }
// If we have a pointer then dereference it
if v.Kind() == reflect.Ptr {
@@ -73,13 +79,6 @@ func ParseValue(v reflect.Value, s string) error {
}
v.Set(reflect.ValueOf(*addr))
return nil
- case net.IP:
- ip := net.ParseIP(s)
- if ip == nil {
- return fmt.Errorf(`invalid IP address: "%s"`, s)
- }
- v.Set(reflect.ValueOf(ip))
- return nil
case net.HardwareAddr:
ip, err := net.ParseMAC(s)
if err != nil {
@@ -126,7 +125,7 @@ func ParseValue(v reflect.Value, s string) error {
// CanParse returns true if the type can be parsed from a string.
func CanParse(t reflect.Type) bool {
// If it implements encoding.TextUnmarshaler then use that
- if t.Implements(textUnmarshalerType) {
+ if t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) {
return true
}
@@ -137,7 +136,7 @@ func CanParse(t reflect.Type) bool {
// Check for other special types
switch t {
- case durationType, mailAddressType, ipType, macType:
+ case durationType, mailAddressType, macType:
return true
}
diff --git a/scalar_test.go b/scalar_test.go
index d70bd32..9a1ef6a 100644
--- a/scalar_test.go
+++ b/scalar_test.go
@@ -10,6 +10,15 @@ import (
"github.com/stretchr/testify/require"
)
+type textUnmarshaler struct {
+ val int
+}
+
+func (f *textUnmarshaler) UnmarshalText(b []byte) error {
+ f.val = len(b)
+ return nil
+}
+
func assertParse(t *testing.T, expected interface{}, str string) {
v := reflect.New(reflect.TypeOf(expected)).Elem()
err := ParseValue(v, str)
@@ -67,6 +76,9 @@ func TestParseValue(t *testing.T) {
// MAC addresses
assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab")
+
+ // custom text unmarshaler
+ assertParse(t, textUnmarshaler{3}, "abc")
}
func TestParse(t *testing.T) {