summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--go.mod8
-rw-r--r--parse.go134
-rw-r--r--parse_test.go102
-rw-r--r--reflect.go17
-rw-r--r--usage.go2
-rw-r--r--usage_test.go37
6 files changed, 232 insertions, 68 deletions
diff --git a/go.mod b/go.mod
index 944b9bc..44ddff5 100644
--- a/go.mod
+++ b/go.mod
@@ -5,4 +5,10 @@ require (
github.com/stretchr/testify v1.7.0
)
-go 1.13
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
+)
+
+go 1.18
diff --git a/parse.go b/parse.go
index dc87947..dc48455 100644
--- a/parse.go
+++ b/parse.go
@@ -43,18 +43,19 @@ func (p path) Child(f reflect.StructField) path {
// spec represents a command line option
type spec struct {
- dest path
- field reflect.StructField // the struct field from which this option was created
- long string // the --long form for this option, or empty if none
- short string // the -s short form for this option, or empty if none
- cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
- required bool // if true, this option must be present on the command line
- positional bool // if true, this option will be looked for in the positional flags
- separate bool // if true, each slice and map entry will have its own --flag
- help string // the help text for this option
- env string // the name of the environment variable for this option, or empty for none
- defaultVal string // default value for this option
- placeholder string // name of the data in help
+ dest path
+ field reflect.StructField // the struct field from which this option was created
+ long string // the --long form for this option, or empty if none
+ short string // the -s short form for this option, or empty if none
+ cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
+ required bool // if true, this option must be present on the command line
+ positional bool // if true, this option will be looked for in the positional flags
+ separate bool // if true, each slice and map entry will have its own --flag
+ help string // the help text for this option
+ env string // the name of the environment variable for this option, or empty for none
+ defaultValue reflect.Value // default value for this option
+ defaultString string // default value for this option, in string form to be displayed in help text
+ placeholder string // name of the data in help
}
// command represents a named subcommand, or the top-level command
@@ -210,18 +211,31 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
return nil, err
}
- // add nonzero field values as defaults
+ // for backwards compatibility, add nonzero field values as defaults
+ // this applies only to the top-level command, not to subcommands (this inconsistency
+ // is the reason that this method for setting default values was deprecated)
for _, spec := range cmd.specs {
- if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
- if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
- str, err := defaultVal.MarshalText()
- if err != nil {
- return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
- }
- spec.defaultVal = string(str)
- } else {
- spec.defaultVal = fmt.Sprintf("%v", v)
+ // get the value
+ v := p.val(spec.dest)
+
+ // if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore
+ if isZero(v) {
+ continue
+ }
+
+ // store as a default
+ spec.defaultValue = v
+
+ // we need a string to display in help text
+ // if MarshalText is implemented then use that
+ if m, ok := v.Interface().(encoding.TextMarshaler); ok {
+ s, err := m.MarshalText()
+ if err != nil {
+ return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
+ spec.defaultString = string(s)
+ } else {
+ spec.defaultString = fmt.Sprintf("%v", v)
}
}
@@ -293,11 +307,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.help = help
}
- defaultVal, hasDefault := field.Tag.Lookup("default")
- if hasDefault {
- spec.defaultVal = defaultVal
- }
-
// Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
for _, key := range strings.Split(tag, ",") {
@@ -324,11 +333,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
}
spec.short = key[1:]
case key == "required":
- if hasDefault {
- errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
- t.Name(), field.Name))
- return false
- }
spec.required = true
case key == "positional":
spec.positional = true
@@ -377,27 +381,60 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.placeholder = strings.ToUpper(spec.field.Name)
}
- // Check whether this field is supported. It's good to do this here rather than
+ // if this is a subcommand then we've done everything we need to do
+ if isSubcommand {
+ return false
+ }
+
+ // check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
- if !isSubcommand {
- cmd.specs = append(cmd.specs, &spec)
+ var err error
+ spec.cardinality, err = cardinalityOf(field.Type)
+ if err != nil {
+ errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
+ t.Name(), field.Name, field.Type.String()))
+ return false
+ }
- var err error
- spec.cardinality, err = cardinalityOf(field.Type)
- if err != nil {
- errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
- t.Name(), field.Name, field.Type.String()))
+ defaultString, hasDefault := field.Tag.Lookup("default")
+ if hasDefault {
+ // we do not support default values for maps and slices
+ if spec.cardinality == multiple {
+ errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
+ t.Name(), field.Name))
return false
}
- if spec.cardinality == multiple && hasDefault {
- errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
+
+ // a required field cannot also have a default value
+ if spec.required {
+ errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
t.Name(), field.Name))
return false
}
+
+ // parse the default value
+ spec.defaultString = defaultString
+ if field.Type.Kind() == reflect.Pointer {
+ // here we have a field of type *T and we create a new T, no need to dereference
+ // in order for the value to be settable
+ spec.defaultValue = reflect.New(field.Type.Elem())
+ } else {
+ // here we have a field of type T and we create a new T and then dereference it
+ // so that the resulting value is settable
+ spec.defaultValue = reflect.New(field.Type).Elem()
+ }
+ err := scalar.ParseValue(spec.defaultValue, defaultString)
+ if err != nil {
+ errs = append(errs, fmt.Sprintf("%s.%s: error processing default value: %v", t.Name(), field.Name, err))
+ return false
+ }
}
+ // add the spec to the list of specs
+ cmd.specs = append(cmd.specs, &spec)
+
// if this was an embedded field then we already returned true up above
return false
})
@@ -680,11 +717,14 @@ func (p *Parser) process(args []string) error {
}
return errors.New(msg)
}
- if !p.config.IgnoreDefault && spec.defaultVal != "" {
- err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
- if err != nil {
- return fmt.Errorf("error processing default value for %s: %v", name, err)
- }
+
+ if spec.defaultValue.IsValid() && !p.config.IgnoreDefault {
+ // One issue here is that if the user now modifies the value then
+ // the default value stored in the spec will be corrupted. There
+ // is no general way to "deep-copy" values in Go, and we still
+ // support the old-style method for specifying defaults as
+ // Go values assigned directly to the struct field, so we are stuck.
+ p.val(spec.dest).Set(spec.defaultValue)
}
}
diff --git a/parse_test.go b/parse_test.go
index 7e84def..5d38306 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -2,6 +2,7 @@ package arg
import (
"bytes"
+ "encoding/json"
"fmt"
"net"
"net/mail"
@@ -1396,13 +1397,21 @@ func TestDefaultOptionValues(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 123, args.A)
- assert.Equal(t, 123, *args.B)
+ if assert.NotNil(t, args.B) {
+ assert.Equal(t, 123, *args.B)
+ }
assert.Equal(t, "xyz", args.C)
- assert.Equal(t, "abc", *args.D)
+ if assert.NotNil(t, args.D) {
+ assert.Equal(t, "abc", *args.D)
+ }
assert.Equal(t, 4.56, args.E)
- assert.Equal(t, 1.23, *args.F)
- assert.True(t, args.G)
+ if assert.NotNil(t, args.F) {
+ assert.Equal(t, 1.23, *args.F)
+ }
assert.True(t, args.G)
+ if assert.NotNil(t, args.H) {
+ assert.True(t, *args.H)
+ }
}
func TestDefaultUnparseable(t *testing.T) {
@@ -1411,7 +1420,7 @@ func TestDefaultUnparseable(t *testing.T) {
}
err := parse("", &args)
- assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`)
+ assert.EqualError(t, err, `.A: error processing default value: strconv.ParseInt: parsing "x": invalid syntax`)
}
func TestDefaultPositionalValues(t *testing.T) {
@@ -1430,13 +1439,21 @@ func TestDefaultPositionalValues(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 456, args.A)
- assert.Equal(t, 789, *args.B)
+ if assert.NotNil(t, args.B) {
+ assert.Equal(t, 789, *args.B)
+ }
assert.Equal(t, "abc", args.C)
- assert.Equal(t, "abc", *args.D)
+ if assert.NotNil(t, args.D) {
+ assert.Equal(t, "abc", *args.D)
+ }
assert.Equal(t, 1.23, args.E)
- assert.Equal(t, 1.23, *args.F)
- assert.True(t, args.G)
+ if assert.NotNil(t, args.F) {
+ assert.Equal(t, 1.23, *args.F)
+ }
assert.True(t, args.G)
+ if assert.NotNil(t, args.H) {
+ assert.True(t, *args.H)
+ }
}
func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
@@ -1450,7 +1467,7 @@ func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
var args struct {
- A []int `default:"123"` // required not allowed with default!
+ A []int `default:"invalid"` // default values not allowed with slices
}
err := parse("", &args)
@@ -1532,3 +1549,68 @@ func TestMustParsePrintsVersion(t *testing.T) {
assert.Equal(t, 0, *exitCode)
assert.Equal(t, "example 3.2.1\n", b.String())
}
+
+type mapWithUnmarshalText struct {
+ val map[string]string
+}
+
+func (v *mapWithUnmarshalText) UnmarshalText(data []byte) error {
+ return json.Unmarshal(data, &v.val)
+}
+
+func TestTextUnmarshalerEmpty(t *testing.T) {
+ // based on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config mapWithUnmarshalText `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Empty(t, args.Config)
+}
+
+func TestTextUnmarshalerEmptyPointer(t *testing.T) {
+ // a slight variant on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config *mapWithUnmarshalText `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Config)
+}
+
+// similar to the above but also implements MarshalText
+type mapWithMarshalText struct {
+ val map[string]string
+}
+
+func (v *mapWithMarshalText) MarshalText(data []byte) error {
+ return json.Unmarshal(data, &v.val)
+}
+
+func (v *mapWithMarshalText) UnmarshalText(data []byte) error {
+ return json.Unmarshal(data, &v.val)
+}
+
+func TestTextMarshalerUnmarshalerEmpty(t *testing.T) {
+ // based on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config mapWithMarshalText `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Empty(t, args.Config)
+}
+
+func TestTextMarshalerUnmarshalerEmptyPointer(t *testing.T) {
+ // a slight variant on https://github.com/alexflint/go-arg/issues/184
+ var args struct {
+ Config *mapWithMarshalText `arg:"--config"`
+ }
+
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Config)
+}
diff --git a/reflect.go b/reflect.go
index cd80be7..466d65f 100644
--- a/reflect.go
+++ b/reflect.go
@@ -13,9 +13,9 @@ import (
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// cardinality tracks how many tokens are expected for a given spec
-// - zero is a boolean, which does to expect any value
-// - one is an ordinary option that will be parsed from a single token
-// - multiple is a slice or map that can accept zero or more tokens
+// - zero is a boolean, which does to expect any value
+// - one is an ordinary option that will be parsed from a single token
+// - multiple is a slice or map that can accept zero or more tokens
type cardinality int
const (
@@ -74,10 +74,10 @@ func cardinalityOf(t reflect.Type) (cardinality, error) {
}
}
-// isBoolean returns true if the type can be parsed from a single string
+// isBoolean returns true if the type is a boolean or a pointer to a boolean
func isBoolean(t reflect.Type) bool {
switch {
- case t.Implements(textUnmarshalerType):
+ case isTextUnmarshaler(t):
return false
case t.Kind() == reflect.Bool:
return true
@@ -88,6 +88,11 @@ func isBoolean(t reflect.Type) bool {
}
}
+// isTextUnmarshaler returns true if the type or its pointer implements encoding.TextUnmarshaler
+func isTextUnmarshaler(t reflect.Type) bool {
+ return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType)
+}
+
// isExported returns true if the struct field name is exported
func isExported(field string) bool {
r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8
@@ -97,7 +102,7 @@ func isExported(field string) bool {
// isZero returns true if v contains the zero value for its type
func isZero(v reflect.Value) bool {
t := v.Type()
- if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
+ if t.Kind() == reflect.Pointer || t.Kind() == reflect.Slice || t.Kind() == reflect.Map || t.Kind() == reflect.Chan || t.Kind() == reflect.Interface {
return v.IsNil()
}
if !t.Comparable() {
diff --git a/usage.go b/usage.go
index 7ba06cc..7a480c3 100644
--- a/usage.go
+++ b/usage.go
@@ -305,7 +305,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
ways = append(ways, synopsis(spec, "-"+spec.short))
}
if len(ways) > 0 {
- printTwoCols(w, strings.Join(ways, ", "), spec.help, spec.defaultVal, spec.env)
+ printTwoCols(w, strings.Join(ways, ", "), spec.help, spec.defaultString, spec.env)
}
}
diff --git a/usage_test.go b/usage_test.go
index fd67fc8..be5894a 100644
--- a/usage_test.go
+++ b/usage_test.go
@@ -50,7 +50,7 @@ Options:
--optimize OPTIMIZE, -O OPTIMIZE
optimization level
--ids IDS Ids
- --values VALUES Values [default: [3.14 42 256]]
+ --values VALUES Values
--workers WORKERS, -w WORKERS
number of workers to start [default: 10, env: WORKERS]
--testenv TESTENV, -a TESTENV [env: TEST_ENV]
@@ -74,7 +74,6 @@ Options:
}
args.Name = "Foo Bar"
args.Value = 42
- args.Values = []float64{3.14, 42, 256}
args.File = &NameDotName{"scratch", "txt"}
p, err := NewParser(Config{Program: "example"}, &args)
require.NoError(t, err)
@@ -506,7 +505,7 @@ Options:
ShortOnly2 string `arg:"-b,--,required" help:"some help2"`
}
p, err := NewParser(Config{Program: "example"}, &args)
- assert.NoError(t, err)
+ require.NoError(t, err)
var help bytes.Buffer
p.WriteHelp(&help)
@@ -633,3 +632,35 @@ error: something went wrong
assert.Equal(t, expectedStdout[1:], b.String())
assert.Equal(t, -1, exitCode)
}
+
+type lengthOf struct {
+ Length int
+}
+
+func (p *lengthOf) UnmarshalText(b []byte) error {
+ p.Length = len(b)
+ return nil
+}
+
+func TestHelpShowsDefaultValueFromOriginalTag(t *testing.T) {
+ // check that the usage text prints the original string from the default tag, not
+ // the serialization of the parsed value
+
+ expectedHelp := `
+Usage: example [--test TEST]
+
+Options:
+ --test TEST [default: some_default_value]
+ --help, -h display this help and exit
+`
+
+ var args struct {
+ Test *lengthOf `default:"some_default_value"`
+ }
+ p, err := NewParser(Config{Program: "example"}, &args)
+ require.NoError(t, err)
+
+ var help bytes.Buffer
+ p.WriteHelp(&help)
+ assert.Equal(t, expectedHelp[1:], help.String())
+}