summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2022-06-09 11:21:29 -0400
committerAlex Flint <[email protected]>2022-06-09 11:21:29 -0400
commit23b2b67fe299b63a072a3541f34d57757d0b8df0 (patch)
treeb5abb6cece5d2829bb134cf7e7c0d58216595035
parentf0f44b65d1179ccedb4c56f493f97ec569a6654e (diff)
fix issue #184
-rw-r--r--go.mod8
-rw-r--r--parse.go43
-rw-r--r--parse_test.go66
-rw-r--r--reflect.go19
-rw-r--r--usage_test.go3
5 files changed, 123 insertions, 16 deletions
diff --git a/go.mod b/go.mod
index 67ac880..0823012 100644
--- a/go.mod
+++ b/go.mod
@@ -5,4 +5,10 @@ require (
github.com/stretchr/testify v1.7.0
)
-go 1.13
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
+)
+
+go 1.18
diff --git a/parse.go b/parse.go
index 7588dfb..28ed014 100644
--- a/parse.go
+++ b/parse.go
@@ -208,18 +208,41 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
return nil, err
}
- // add nonzero field values as defaults
+ // for backwards compatibility, add nonzero field values as defaults
for _, spec := range cmd.specs {
- if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
- if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
- str, err := defaultVal.MarshalText()
- if err != nil {
- return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
- }
- spec.defaultVal = string(str)
- } else {
- spec.defaultVal = fmt.Sprintf("%v", v)
+ // do not read default when UnmarshalText is implemented but not MarshalText
+ if isTextUnmarshaler(spec.field.Type) && !isTextMarshaler(spec.field.Type) {
+ continue
+ }
+
+ // do not process types that require multiple values
+ cardinality, _ := cardinalityOf(spec.field.Type)
+ if cardinality != one {
+ continue
+ }
+
+ // get the value
+ v := p.val(spec.dest)
+ if !v.IsValid() {
+ continue
+ }
+
+ // if MarshalText is implemented then use that
+ if m, ok := v.Interface().(encoding.TextMarshaler); ok {
+ if v.IsNil() {
+ continue
+ }
+ s, err := m.MarshalText()
+ if err != nil {
+ return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
+ spec.defaultVal = string(s)
+ continue
+ }
+
+ // finally, use the value as a default if it is non-zero
+ if !isZero(v) {
+ spec.defaultVal = fmt.Sprintf("%v", v)
}
}
diff --git a/parse_test.go b/parse_test.go
index 2d0ef7a..0d58598 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -2,6 +2,7 @@ package arg
import (
"bytes"
+ "encoding/json"
"fmt"
"net"
"net/mail"
@@ -1456,3 +1457,68 @@ func TestMustParsePrintsVersion(t *testing.T) {
assert.Equal(t, 0, *exitCode)
assert.Equal(t, "example 3.2.1\n", b.String())
}
+
+type jsonMap struct {
+ val map[string]string
+}
+
+func (v *jsonMap) UnmarshalText(data []byte) error {
+ return json.Unmarshal(data, &v.val)
+}
+
+func TestTextUnmarshallerEmpty(t *testing.T) {
+ // based on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config jsonMap `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Empty(t, args.Config)
+}
+
+func TestTextUnmarshallerEmptyPointer(t *testing.T) {
+ // a slight variant on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config *jsonMap `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Config)
+}
+
+// similar to the above but also implements MarshalText
+type jsonMap2[T any] struct {
+ val T
+}
+
+func (v *jsonMap2[T]) MarshalText(data []byte) error {
+ return json.Unmarshal(data, &v.val)
+}
+
+func (v *jsonMap2[T]) UnmarshalText(data []byte) error {
+ return json.Unmarshal(data, &v.val)
+}
+
+func TestTextMarshallerUnmarshallerEmpty(t *testing.T) {
+ // based on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config jsonMap2[map[string]string] `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Empty(t, args.Config)
+}
+
+func TestTextMarshallerUnmarshallerEmptyPointer(t *testing.T) {
+ // a slight variant on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config *jsonMap2[map[string]string] `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Config)
+}
diff --git a/reflect.go b/reflect.go
index cd80be7..b87db2a 100644
--- a/reflect.go
+++ b/reflect.go
@@ -10,7 +10,10 @@ import (
scalar "github.com/alexflint/go-scalar"
)
-var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
+var (
+ textMarshalerType = reflect.TypeOf([]encoding.TextMarshaler{}).Elem()
+ textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
+)
// cardinality tracks how many tokens are expected for a given spec
// - zero is a boolean, which does to expect any value
@@ -74,10 +77,10 @@ func cardinalityOf(t reflect.Type) (cardinality, error) {
}
}
-// isBoolean returns true if the type can be parsed from a single string
+// isBoolean returns true if the type is a boolean or a pointer to a boolean
func isBoolean(t reflect.Type) bool {
switch {
- case t.Implements(textUnmarshalerType):
+ case isTextUnmarshaler(t):
return false
case t.Kind() == reflect.Bool:
return true
@@ -88,6 +91,16 @@ func isBoolean(t reflect.Type) bool {
}
}
+// isTextMarshaler returns true if the type or its pointer implements encoding.TextMarshaler
+func isTextMarshaler(t reflect.Type) bool {
+ return t.Implements(textMarshalerType) || reflect.PtrTo(t).Implements(textMarshalerType)
+}
+
+// isTextUnmarshaler returns true if the type or its pointer implements encoding.TextUnmarshaler
+func isTextUnmarshaler(t reflect.Type) bool {
+ return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType)
+}
+
// isExported returns true if the struct field name is exported
func isExported(field string) bool {
r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8
diff --git a/usage_test.go b/usage_test.go
index 1744536..0a7ddd8 100644
--- a/usage_test.go
+++ b/usage_test.go
@@ -50,7 +50,7 @@ Options:
--optimize OPTIMIZE, -O OPTIMIZE
optimization level
--ids IDS Ids
- --values VALUES Values [default: [3.14 42 256]]
+ --values VALUES Values
--workers WORKERS, -w WORKERS
number of workers to start [default: 10, env: WORKERS]
--testenv TESTENV, -a TESTENV [env: TEST_ENV]
@@ -74,7 +74,6 @@ Options:
}
args.Name = "Foo Bar"
args.Value = 42
- args.Values = []float64{3.14, 42, 256}
args.File = &NameDotName{"scratch", "txt"}
p, err := NewParser(Config{Program: "example"}, &args)
require.NoError(t, err)