summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--go.mod2
-rw-r--r--parse.go39
-rw-r--r--parse_test.go68
3 files changed, 103 insertions, 6 deletions
diff --git a/go.mod b/go.mod
index c4c4879..14c6119 100644
--- a/go.mod
+++ b/go.mod
@@ -4,3 +4,5 @@ require (
github.com/alexflint/go-scalar v1.0.0
github.com/stretchr/testify v1.2.2
)
+
+go 1.13
diff --git a/parse.go b/parse.go
index a29258a..d234ed2 100644
--- a/parse.go
+++ b/parse.go
@@ -54,6 +54,7 @@ type spec struct {
help string
env string
boolean bool
+ defaultVal string // default value for this option, only if provided as a struct tag
}
// command represents a named subcommand, or the top-level command
@@ -250,6 +251,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.help = help
}
+ defaultVal, hasDefault := field.Tag.Lookup("default")
+ if hasDefault {
+ spec.defaultVal = defaultVal
+ }
+
// Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
if tag != "" {
@@ -274,6 +280,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
}
spec.short = key[1:]
case key == "required":
+ if hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
+ t.Name(), field.Name))
+ return false
+ }
spec.required = true
case key == "positional":
spec.positional = true
@@ -328,6 +339,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
t.Name(), field.Name, field.Type.String()))
return false
}
+ if spec.multiple && hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
+ t.Name(), field.Name))
+ return false
+ }
}
// if this was an embedded field then we already returned true up above
@@ -570,15 +586,26 @@ func (p *Parser) process(args []string) error {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
- // finally check that all the required args were provided
+ // fill in defaults and check that all the required args were provided
for _, spec := range specs {
- if spec.required && !wasPresent[spec] {
- name := spec.long
- if !spec.positional {
- name = "--" + spec.long
- }
+ if wasPresent[spec] {
+ continue
+ }
+
+ name := spec.long
+ if !spec.positional {
+ name = "--" + spec.long
+ }
+
+ if spec.required {
return fmt.Errorf("%s is required", name)
}
+ if spec.defaultVal != "" {
+ err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
+ if err != nil {
+ return fmt.Errorf("error processing default value for %s: %v", name, err)
+ }
+ }
}
return nil
diff --git a/parse_test.go b/parse_test.go
index 5909472..9cd8bce 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -1057,3 +1057,71 @@ func TestMultipleTerminates(t *testing.T) {
assert.Equal(t, []string{"a", "b"}, args.X)
assert.Equal(t, "c", args.Y)
}
+
+func TestDefaultOptionValues(t *testing.T) {
+ var args struct {
+ A int `default:"123"`
+ B *int `default:"123"`
+ C string `default:"abc"`
+ D *string `default:"abc"`
+ E float64 `default:"1.23"`
+ F *float64 `default:"1.23"`
+ G bool `default:"true"`
+ H *bool `default:"true"`
+ }
+
+ err := parse("--c=xyz --e=4.56", &args)
+ require.NoError(t, err)
+
+ assert.Equal(t, 123, args.A)
+ assert.Equal(t, 123, *args.B)
+ assert.Equal(t, "xyz", args.C)
+ assert.Equal(t, "abc", *args.D)
+ assert.Equal(t, 4.56, args.E)
+ assert.Equal(t, 1.23, *args.F)
+ assert.True(t, args.G)
+ assert.True(t, args.G)
+}
+
+func TestDefaultPositionalValues(t *testing.T) {
+ var args struct {
+ A int `arg:"positional" default:"123"`
+ B *int `arg:"positional" default:"123"`
+ C string `arg:"positional" default:"abc"`
+ D *string `arg:"positional" default:"abc"`
+ E float64 `arg:"positional" default:"1.23"`
+ F *float64 `arg:"positional" default:"1.23"`
+ G bool `arg:"positional" default:"true"`
+ H *bool `arg:"positional" default:"true"`
+ }
+
+ err := parse("456 789", &args)
+ require.NoError(t, err)
+
+ assert.Equal(t, 456, args.A)
+ assert.Equal(t, 789, *args.B)
+ assert.Equal(t, "abc", args.C)
+ assert.Equal(t, "abc", *args.D)
+ assert.Equal(t, 1.23, args.E)
+ assert.Equal(t, 1.23, *args.F)
+ assert.True(t, args.G)
+ assert.True(t, args.G)
+}
+
+func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
+ var args struct {
+ A int `arg:"required" default:"123"` // required not allowed with default!
+ }
+
+ err := parse("", &args)
+ assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified")
+}
+
+func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
+ var args struct {
+ A []int `default:"123"` // required not allowed with default!
+ }
+
+ err := parse("", &args)
+ assert.EqualError(t, err, ".A: default values are not supported for slice fields")
+}