diff options
Diffstat (limited to 'parse.go')
| -rw-r--r-- | parse.go | 68 |
1 files changed, 62 insertions, 6 deletions
@@ -1,6 +1,7 @@ package arg import ( + "encoding" "encoding/csv" "errors" "fmt" @@ -54,6 +55,7 @@ type spec struct { help string env string boolean bool + defaultVal string // default value for this option } // command represents a named subcommand, or the top-level command @@ -192,6 +194,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { if err != nil { return nil, err } + + // add nonzero field values as defaults + for _, spec := range cmd.specs { + if v := p.val(spec.dest); v.IsValid() && !isZero(v) { + if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { + str, err := defaultVal.MarshalText() + if err != nil { + return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) + } + spec.defaultVal = string(str) + } else { + spec.defaultVal = fmt.Sprintf("%v", v) + } + } + } + p.cmd.specs = append(p.cmd.specs, cmd.specs...) p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) @@ -250,6 +268,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { spec.help = help } + defaultVal, hasDefault := field.Tag.Lookup("default") + if hasDefault { + spec.defaultVal = defaultVal + } + // Look at the tag var isSubcommand bool // tracks whether this field is a subcommand if tag != "" { @@ -274,6 +297,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { } spec.short = key[1:] case key == "required": + if hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", + t.Name(), field.Name)) + return false + } spec.required = true case key == "positional": spec.positional = true @@ -328,6 +356,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { t.Name(), field.Name, field.Type.String())) return false } + if spec.multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields", + t.Name(), field.Name)) + return false + } } // if this was an embedded field then we already returned true up above @@ -570,15 +603,26 @@ func (p *Parser) process(args []string) error { return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) } - // finally check that all the required args were provided + // fill in defaults and check that all the required args were provided for _, spec := range specs { - if spec.required && !wasPresent[spec] { - name := spec.long - if !spec.positional { - name = "--" + spec.long - } + if wasPresent[spec] { + continue + } + + name := spec.long + if !spec.positional { + name = "--" + spec.long + } + + if spec.required { return fmt.Errorf("%s is required", name) } + if spec.defaultVal != "" { + err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + } } return nil @@ -679,3 +723,15 @@ func findSubcommand(cmds []*command, name string) *command { } return nil } + +// isZero returns true if v contains the zero value for its type +func isZero(v reflect.Value) bool { + t := v.Type() + if t.Kind() == reflect.Slice { + return v.IsNil() + } + if !t.Comparable() { + return false + } + return v.Interface() == reflect.Zero(t).Interface() +} |
