summaryrefslogtreecommitdiff
path: root/parse.go
diff options
context:
space:
mode:
Diffstat (limited to 'parse.go')
-rw-r--r--parse.go68
1 files changed, 62 insertions, 6 deletions
diff --git a/parse.go b/parse.go
index a29258a..bc156df 100644
--- a/parse.go
+++ b/parse.go
@@ -1,6 +1,7 @@
package arg
import (
+ "encoding"
"encoding/csv"
"errors"
"fmt"
@@ -54,6 +55,7 @@ type spec struct {
help string
env string
boolean bool
+ defaultVal string // default value for this option
}
// command represents a named subcommand, or the top-level command
@@ -192,6 +194,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
if err != nil {
return nil, err
}
+
+ // add nonzero field values as defaults
+ for _, spec := range cmd.specs {
+ if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
+ if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
+ str, err := defaultVal.MarshalText()
+ if err != nil {
+ return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
+ }
+ spec.defaultVal = string(str)
+ } else {
+ spec.defaultVal = fmt.Sprintf("%v", v)
+ }
+ }
+ }
+
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
@@ -250,6 +268,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.help = help
}
+ defaultVal, hasDefault := field.Tag.Lookup("default")
+ if hasDefault {
+ spec.defaultVal = defaultVal
+ }
+
// Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
if tag != "" {
@@ -274,6 +297,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
}
spec.short = key[1:]
case key == "required":
+ if hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
+ t.Name(), field.Name))
+ return false
+ }
spec.required = true
case key == "positional":
spec.positional = true
@@ -328,6 +356,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
t.Name(), field.Name, field.Type.String()))
return false
}
+ if spec.multiple && hasDefault {
+ errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
+ t.Name(), field.Name))
+ return false
+ }
}
// if this was an embedded field then we already returned true up above
@@ -570,15 +603,26 @@ func (p *Parser) process(args []string) error {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
- // finally check that all the required args were provided
+ // fill in defaults and check that all the required args were provided
for _, spec := range specs {
- if spec.required && !wasPresent[spec] {
- name := spec.long
- if !spec.positional {
- name = "--" + spec.long
- }
+ if wasPresent[spec] {
+ continue
+ }
+
+ name := spec.long
+ if !spec.positional {
+ name = "--" + spec.long
+ }
+
+ if spec.required {
return fmt.Errorf("%s is required", name)
}
+ if spec.defaultVal != "" {
+ err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
+ if err != nil {
+ return fmt.Errorf("error processing default value for %s: %v", name, err)
+ }
+ }
}
return nil
@@ -679,3 +723,15 @@ func findSubcommand(cmds []*command, name string) *command {
}
return nil
}
+
+// isZero returns true if v contains the zero value for its type
+func isZero(v reflect.Value) bool {
+ t := v.Type()
+ if t.Kind() == reflect.Slice {
+ return v.IsNil()
+ }
+ if !t.Comparable() {
+ return false
+ }
+ return v.Interface() == reflect.Zero(t).Interface()
+}