summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.travis.yml5
-rw-r--r--README.md15
-rw-r--r--example_test.go5
-rw-r--r--parse.go68
-rw-r--r--parse_test.go77
-rw-r--r--usage.go36
-rw-r--r--usage_test.go27
7 files changed, 179 insertions, 54 deletions
diff --git a/.travis.yml b/.travis.yml
index d2a00fd..87ef507 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,10 +1,7 @@
language: go
go:
- - "1.10"
- "1.12"
- - tip
-env:
- - GO111MODULE=on # will only be used in go 1.11
+ - "1.13"
before_install:
- go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls
diff --git a/README.md b/README.md
index c516c51..1f02559 100644
--- a/README.md
+++ b/README.md
@@ -142,10 +142,20 @@ Options:
```go
var args struct {
+ Foo string `default:"abc"`
+ Bar bool
+}
+arg.MustParse(&args)
+```
+
+### Default values (before v1.2)
+
+```go
+var args struct {
Foo string
Bar bool
}
-args.Foo = "default value"
+arg.Foo = "abc"
arg.MustParse(&args)
```
@@ -307,9 +317,8 @@ func (n *NameDotName) MarshalText() ([]byte, error) {
func main() {
var args struct {
- Name NameDotName
+ Name NameDotName `default:"file.txt"`
}
- args.Name = NameDotName{"file", "txt"} // set default value
arg.MustParse(&args)
fmt.Printf("%#v\n", args.Name)
}
diff --git a/example_test.go b/example_test.go
index 2188253..f71fbeb 100644
--- a/example_test.go
+++ b/example_test.go
@@ -30,12 +30,11 @@ func Example_defaultValues() {
os.Args = split("./example")
var args struct {
- Foo string
+ Foo string `default:"abc"`
}
- args.Foo = "default value"
MustParse(&args)
fmt.Println(args.Foo)
- // output: default value
+ // output: abc
}
// This example demonstrates arguments that are required
diff --git a/parse.go b/parse.go
index a29258a..bc156df 100644
--- a/parse.go
+++ b/parse.go
@@ -1,6 +1,7 @@
package arg
import (
+ "encoding"
"encoding/csv"
"errors"
"fmt"
@@ -54,6 +55,7 @@ type spec struct {
help string
env string
boolean bool
+ defaultVal string // default value for this option
}
// command represents a named subcommand, or the top-level command
@@ -192,6 +194,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
if err != nil {
return nil, err
}
+
+ // add nonzero field values as defaults
+ for _, spec := range cmd.specs {
+ if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
+ if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
+ str, err := defaultVal.MarshalText()
+ if err != nil {
+ return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
+ }
+ spec.defaultVal = string(str)
+ } else {
+ spec.defaultVal = fmt.Sprintf("%v", v)
+ }
+ }
+ }
+
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
@@ -250,6 +268,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.help = help
}
+ defaultVal, hasDefault := field.Tag.Lookup("default")
+ if hasDefault {
+ spec.defaultVal = defaultVal
+ }
+
// Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
if tag != "" {
@@ -274,6 +297,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
}
spec.short = key[1:]
case key == "required":
+ if hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
+ t.Name(), field.Name))
+ return false
+ }
spec.required = true
case key == "positional":
spec.positional = true
@@ -328,6 +356,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
t.Name(), field.Name, field.Type.String()))
return false
}
+ if spec.multiple && hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
+ t.Name(), field.Name))
+ return false
+ }
}
// if this was an embedded field then we already returned true up above
@@ -570,15 +603,26 @@ func (p *Parser) process(args []string) error {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
- // finally check that all the required args were provided
+ // fill in defaults and check that all the required args were provided
for _, spec := range specs {
- if spec.required && !wasPresent[spec] {
- name := spec.long
- if !spec.positional {
- name = "--" + spec.long
- }
+ if wasPresent[spec] {
+ continue
+ }
+
+ name := spec.long
+ if !spec.positional {
+ name = "--" + spec.long
+ }
+
+ if spec.required {
return fmt.Errorf("%s is required", name)
}
+ if spec.defaultVal != "" {
+ err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
+ if err != nil {
+ return fmt.Errorf("error processing default value for %s: %v", name, err)
+ }
+ }
}
return nil
@@ -679,3 +723,15 @@ func findSubcommand(cmds []*command, name string) *command {
}
return nil
}
+
+// isZero returns true if v contains the zero value for its type
+func isZero(v reflect.Value) bool {
+ t := v.Type()
+ if t.Kind() == reflect.Slice {
+ return v.IsNil()
+ }
+ if !t.Comparable() {
+ return false
+ }
+ return v.Interface() == reflect.Zero(t).Interface()
+}
diff --git a/parse_test.go b/parse_test.go
index 5909472..47e9ccd 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -1057,3 +1057,80 @@ func TestMultipleTerminates(t *testing.T) {
assert.Equal(t, []string{"a", "b"}, args.X)
assert.Equal(t, "c", args.Y)
}
+
+func TestDefaultOptionValues(t *testing.T) {
+ var args struct {
+ A int `default:"123"`
+ B *int `default:"123"`
+ C string `default:"abc"`
+ D *string `default:"abc"`
+ E float64 `default:"1.23"`
+ F *float64 `default:"1.23"`
+ G bool `default:"true"`
+ H *bool `default:"true"`
+ }
+
+ err := parse("--c=xyz --e=4.56", &args)
+ require.NoError(t, err)
+
+ assert.Equal(t, 123, args.A)
+ assert.Equal(t, 123, *args.B)
+ assert.Equal(t, "xyz", args.C)
+ assert.Equal(t, "abc", *args.D)
+ assert.Equal(t, 4.56, args.E)
+ assert.Equal(t, 1.23, *args.F)
+ assert.True(t, args.G)
+ assert.True(t, args.G)
+}
+
+func TestDefaultUnparseable(t *testing.T) {
+ var args struct {
+ A int `default:"x"`
+ }
+
+ err := parse("", &args)
+ assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`)
+}
+
+func TestDefaultPositionalValues(t *testing.T) {
+ var args struct {
+ A int `arg:"positional" default:"123"`
+ B *int `arg:"positional" default:"123"`
+ C string `arg:"positional" default:"abc"`
+ D *string `arg:"positional" default:"abc"`
+ E float64 `arg:"positional" default:"1.23"`
+ F *float64 `arg:"positional" default:"1.23"`
+ G bool `arg:"positional" default:"true"`
+ H *bool `arg:"positional" default:"true"`
+ }
+
+ err := parse("456 789", &args)
+ require.NoError(t, err)
+
+ assert.Equal(t, 456, args.A)
+ assert.Equal(t, 789, *args.B)
+ assert.Equal(t, "abc", args.C)
+ assert.Equal(t, "abc", *args.D)
+ assert.Equal(t, 1.23, args.E)
+ assert.Equal(t, 1.23, *args.F)
+ assert.True(t, args.G)
+ assert.True(t, args.G)
+}
+
+func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
+ var args struct {
+ A int `arg:"required" default:"123"` // required not allowed with default!
+ }
+
+ err := parse("", &args)
+ assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified")
+}
+
+func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
+ var args struct {
+ A []int `default:"123"` // required not allowed with default!
+ }
+
+ err := parse("", &args)
+ assert.EqualError(t, err, ".A: default values are not supported for slice fields")
+}
diff --git a/usage.go b/usage.go
index 33b6184..43db703 100644
--- a/usage.go
+++ b/usage.go
@@ -1,11 +1,9 @@
package arg
import (
- "encoding"
"fmt"
"io"
"os"
- "reflect"
"strings"
)
@@ -94,7 +92,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
fmt.Fprint(w, "\n")
}
-func printTwoCols(w io.Writer, left, help string, defaultVal *string) {
+func printTwoCols(w io.Writer, left, help string, defaultVal string) {
lhs := " " + left
fmt.Fprint(w, lhs)
if help != "" {
@@ -105,8 +103,8 @@ func printTwoCols(w io.Writer, left, help string, defaultVal *string) {
}
fmt.Fprint(w, help)
}
- if defaultVal != nil {
- fmt.Fprintf(w, " [default: %s]", *defaultVal)
+ if defaultVal != "" {
+ fmt.Fprintf(w, " [default: %s]", defaultVal)
}
fmt.Fprint(w, "\n")
}
@@ -136,7 +134,7 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
if len(positionals) > 0 {
fmt.Fprint(w, "\nPositional arguments:\n")
for _, spec := range positionals {
- printTwoCols(w, strings.ToUpper(spec.long), spec.help, nil)
+ printTwoCols(w, strings.ToUpper(spec.long), spec.help, "")
}
}
@@ -165,7 +163,7 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
if len(cmd.subcommands) > 0 {
fmt.Fprint(w, "\nCommands:\n")
for _, subcmd := range cmd.subcommands {
- printTwoCols(w, subcmd.name, subcmd.help, nil)
+ printTwoCols(w, subcmd.name, subcmd.help, "")
}
}
}
@@ -175,29 +173,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
if spec.short != "" {
left += ", " + synopsis(spec, "-"+spec.short)
}
-
- // If spec.dest is not the zero value then a default value has been added.
- var v reflect.Value
- if len(spec.dest.fields) > 0 {
- v = p.val(spec.dest)
- }
-
- var defaultVal *string
- if v.IsValid() {
- z := reflect.Zero(v.Type())
- if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() {
- if scalar, ok := v.Interface().(encoding.TextMarshaler); ok {
- if value, err := scalar.MarshalText(); err != nil {
- defaultVal = ptrTo(fmt.Sprintf("error: %v", err))
- } else {
- defaultVal = ptrTo(fmt.Sprintf("%v", string(value)))
- }
- } else {
- defaultVal = ptrTo(fmt.Sprintf("%v", v))
- }
- }
- }
- printTwoCols(w, left, spec.help, defaultVal)
+ printTwoCols(w, left, spec.help, spec.defaultVal)
}
func synopsis(spec *spec, form string) string {
diff --git a/usage_test.go b/usage_test.go
index fc0b8c5..d9d33f0 100644
--- a/usage_test.go
+++ b/usage_test.go
@@ -96,28 +96,39 @@ func (n *MyEnum) MarshalText() ([]byte, error) {
return nil, errors.New("There was a problem")
}
-func TestUsageError(t *testing.T) {
- expectedHelp := `Usage: example [--name NAME]
+func TestUsageWithDefaults(t *testing.T) {
+ expectedHelp := `Usage: example [--label LABEL] [--content CONTENT]
Options:
- --name NAME [default: error: There was a problem]
+ --label LABEL [default: cat]
+ --content CONTENT [default: dog]
--help, -h display this help and exit
`
var args struct {
- Name *MyEnum
+ Label string
+ Content string `default:"dog"`
}
- v := MyEnum(42)
- args.Name = &v
+ args.Label = "cat"
p, err := NewParser(Config{"example"}, &args)
-
- // NB: some might might expect there to be an error here
require.NoError(t, err)
+ args.Label = "should_ignore_this"
+
var help bytes.Buffer
p.WriteHelp(&help)
assert.Equal(t, expectedHelp, help.String())
}
+func TestUsageCannotMarshalToString(t *testing.T) {
+ var args struct {
+ Name *MyEnum
+ }
+ v := MyEnum(42)
+ args.Name = &v
+ _, err := NewParser(Config{"example"}, &args)
+ assert.EqualError(t, err, `args.Name: error marshaling default value to string: There was a problem`)
+}
+
func TestUsageLongPositionalWithHelp_legacyForm(t *testing.T) {
expectedHelp := `Usage: example VERYLONGPOSITIONALWITHHELP