summaryrefslogtreecommitdiff
path: root/parse.go
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2019-04-30 12:54:28 -0700
committerAlex Flint <[email protected]>2019-04-30 12:54:28 -0700
commit4e977796af5ef0863a674ef468c5036dcca20623 (patch)
tree84d6670b0d0fce9013058e8d1323aedde48c4d20 /parse.go
parentddec9e9e4febd4f367f31297ad2744d683f474b4 (diff)
add recursive expansion of subcommands
Diffstat (limited to 'parse.go')
-rw-r--r--parse.go130
1 files changed, 100 insertions, 30 deletions
diff --git a/parse.go b/parse.go
index b5b76b8..353b365 100644
--- a/parse.go
+++ b/parse.go
@@ -152,7 +152,6 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
if t.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
}
- t = t.Elem()
cmd, err := cmdFromStruct(name, t, nil, i)
if err != nil {
@@ -172,8 +171,16 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
}
func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) {
+ // commands can only be created from pointers to structs
+ if t.Kind() != reflect.Ptr {
+ return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a %s",
+ strings.Join(path, "."), t.Kind())
+ }
+
+ t = t.Elem()
if t.Kind() != reflect.Struct {
- panic(fmt.Sprintf("%v is not a struct pointer", t))
+ return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a pointer to %s",
+ strings.Join(path, "."), t.Kind())
}
var cmd command
@@ -190,9 +197,13 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
return true
}
+ // duplicate the entire path to avoid slice overwrites
+ subpath := make([]string, len(path)+1)
+ copy(subpath, append(path, field.Name))
+
spec := spec{
root: root,
- path: append(path, field.Name),
+ path: subpath,
long: strings.ToLower(field.Name),
typ: field.Type,
}
@@ -258,7 +269,7 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
cmdname = strings.ToLower(field.Name)
}
- subcmd, err := cmdFromStruct(cmdname, field.Type, append(path, field.Name), root)
+ subcmd, err := cmdFromStruct(cmdname, field.Type, subpath, root)
if err != nil {
errs = append(errs, err.Error())
return false
@@ -281,6 +292,17 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
return nil, errors.New(strings.Join(errs, "\n"))
}
+ // check that we don't have both positionals and subcommands
+ var hasPositional bool
+ for _, spec := range cmd.specs {
+ if spec.positional {
+ hasPositional = true
+ }
+ }
+ if hasPositional && len(cmd.subcommands) > 0 {
+ return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t)
+ }
+
return &cmd, nil
}
@@ -301,30 +323,11 @@ func (p *Parser) Parse(args []string) error {
}
// Process all command line arguments
- return p.process(p.cmd.specs, args)
+ return p.process(args)
}
-// process goes through arguments one-by-one, parses them, and assigns the result to
-// the underlying struct field
-func (p *Parser) process(specs []*spec, args []string) error {
- // track the options we have seen
- wasPresent := make(map[*spec]bool)
-
- // construct a map from --option to spec
- optionMap := make(map[string]*spec)
- for _, spec := range specs {
- if spec.positional {
- continue
- }
- if spec.long != "" {
- optionMap[spec.long] = spec
- }
- if spec.short != "" {
- optionMap[spec.short] = spec
- }
- }
-
- // deal with environment vars
+// process environment vars for the given arguments
+func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error {
for _, spec := range specs {
if spec.env == "" {
continue
@@ -361,6 +364,28 @@ func (p *Parser) process(specs []*spec, args []string) error {
wasPresent[spec] = true
}
+ return nil
+}
+
+// process goes through arguments one-by-one, parses them, and assigns the result to
+// the underlying struct field
+func (p *Parser) process(args []string) error {
+ // track the options we have seen
+ wasPresent := make(map[*spec]bool)
+
+ // union of specs for the chain of subcommands encountered so far
+ curCmd := p.cmd
+
+ // make a copy of the specs because we will add to this list each time we expand a subcommand
+ specs := make([]*spec, len(curCmd.specs))
+ copy(specs, curCmd.specs)
+
+ // deal with environment vars
+ err := p.captureEnvVars(specs, wasPresent)
+ if err != nil {
+ return err
+ }
+
// process each string from the command line
var allpositional bool
var positionals []string
@@ -374,7 +399,28 @@ func (p *Parser) process(specs []*spec, args []string) error {
}
if !isFlag(arg) || allpositional {
- positionals = append(positionals, arg)
+ // each subcommand can have either subcommands or positionals, but not both
+ if len(curCmd.subcommands) == 0 {
+ positionals = append(positionals, arg)
+ continue
+ }
+
+ // if we have a subcommand then make sure it is valid for the current context
+ subcmd := findSubcommand(curCmd.subcommands, arg)
+ if subcmd == nil {
+ return fmt.Errorf("invalid subcommand: %s", arg)
+ }
+
+ // add the new options to the set of allowed options
+ specs = append(specs, subcmd.specs...)
+
+ // capture environment vars for these new options
+ err := p.captureEnvVars(subcmd.specs, wasPresent)
+ if err != nil {
+ return err
+ }
+
+ curCmd = subcmd
continue
}
@@ -386,9 +432,10 @@ func (p *Parser) process(specs []*spec, args []string) error {
opt = opt[:pos]
}
- // lookup the spec for this option
- spec, ok := optionMap[opt]
- if !ok {
+ // lookup the spec for this option (note that the "specs" slice changes as
+ // we expand subcommands so it is better not to use a map)
+ spec := findOption(specs, opt)
+ if spec == nil {
return fmt.Errorf("unknown argument %s", arg)
}
wasPresent[spec] = true
@@ -630,3 +677,26 @@ func isBoolean(t reflect.Type) bool {
return false
}
}
+
+// findOption finds an option from its name, or returns null if no spec is found
+func findOption(specs []*spec, name string) *spec {
+ for _, spec := range specs {
+ if spec.positional {
+ continue
+ }
+ if spec.long == name || spec.short == name {
+ return spec
+ }
+ }
+ return nil
+}
+
+// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found
+func findSubcommand(cmds []*command, name string) *command {
+ for _, cmd := range cmds {
+ if cmd.name == name {
+ return cmd
+ }
+ }
+ return nil
+}