diff options
| -rw-r--r-- | parse.go | 452 | ||||
| -rw-r--r-- | parse_test.go | 7 | ||||
| -rw-r--r-- | subcommand_test.go | 170 | ||||
| -rw-r--r-- | usage.go | 28 |
4 files changed, 522 insertions, 135 deletions
@@ -13,9 +13,33 @@ import ( scalar "github.com/alexflint/go-scalar" ) +// path represents a sequence of steps to find the output location for an +// argument or subcommand in the final destination struct +type path struct { + root int // index of the destination struct + fields []string // sequence of struct field names to traverse +} + +// String gets a string representation of the given path +func (p path) String() string { + return "args." + strings.Join(p.fields, ".") +} + +// Child gets a new path representing a child of this path. +func (p path) Child(child string) path { + // copy the entire slice of fields to avoid possible slice overwrite + subfields := make([]string, len(p.fields)+1) + copy(subfields, append(p.fields, child)) + return path{ + root: p.root, + fields: subfields, + } +} + // spec represents a command line option type spec struct { - dest reflect.Value + dest path + typ reflect.Type long string short string multiple bool @@ -27,6 +51,14 @@ type spec struct { boolean bool } +// command represents a named subcommand, or the top-level command +type command struct { + name string + dest path + specs []*spec + subcommands []*command +} + // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") @@ -79,7 +111,8 @@ type Config struct { // Parser represents a set of command line options with destination values type Parser struct { - specs []*spec + cmd *command + roots []reflect.Value config Config version string description string @@ -102,66 +135,177 @@ type Described interface { } // walkFields calls a function for each field of a struct, recursively expanding struct fields. -func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) { - t := v.Type() +func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { for i := 0; i < t.NumField(); i++ { field := t.Field(i) - val := v.Field(i) - expand := visit(field, val, t) + expand := visit(field, t) if expand && field.Type.Kind() == reflect.Struct { - walkFields(val, visit) + walkFields(field.Type, visit) } } } // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { + // first pick a name for the command for use in the usage text + var name string + switch { + case config.Program != "": + name = config.Program + case len(os.Args) > 0: + name = filepath.Base(os.Args[0]) + default: + name = "program" + } + + // construct a parser p := Parser{ + cmd: &command{name: name}, config: config, } + + // make a list of roots for _, dest := range dests { + p.roots = append(p.roots, reflect.ValueOf(dest)) + } + + // process each of the destination values + for i, dest := range dests { + t := reflect.TypeOf(dest) + if t.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) + } + + cmd, err := cmdFromStruct(name, path{root: i}, t) + if err != nil { + return nil, err + } + p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) + if dest, ok := dest.(Versioned); ok { p.version = dest.Version() } if dest, ok := dest.(Described); ok { p.description = dest.Description() } - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type())) + } + + return &p, nil +} + +func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { + // commands can only be created from pointers to structs + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", + dest, t.Kind()) + } + + t = t.Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", + dest, t.Kind()) + } + + cmd := command{ + name: name, + dest: dest, + } + + var errs []string + walkFields(t, func(field reflect.StructField, t reflect.Type) bool { + // Check for the ignore switch in the tag + tag := field.Tag.Get("arg") + if tag == "-" { + return false + } + + // If this is an embedded struct then recurse into its fields + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true } - v = v.Elem() - if v.Kind() != reflect.Struct { - panic(fmt.Sprintf("%T is not a struct pointer", dest)) + + // duplicate the entire path to avoid slice overwrites + subdest := dest.Child(field.Name) + spec := spec{ + dest: subdest, + long: strings.ToLower(field.Name), + typ: field.Type, } - var errs []string - walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool { - // Check for the ignore switch in the tag - tag := field.Tag.Get("arg") - if tag == "-" { - return false - } + help, exists := field.Tag.Lookup("help") + if exists { + spec.help = help + } - // If this is an embedded struct then recurse into its fields - if field.Anonymous && field.Type.Kind() == reflect.Struct { - return true - } + // Look at the tag + var isSubcommand bool // tracks whether this field is a subcommand + if tag != "" { + for _, key := range strings.Split(tag, ",") { + key = strings.TrimLeft(key, " ") + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } - spec := spec{ - long: strings.ToLower(field.Name), - dest: val, - } + switch { + case strings.HasPrefix(key, "---"): + errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) + return false + } + spec.short = key[1:] + case key == "required": + spec.required = true + case key == "positional": + spec.positional = true + case key == "separate": + spec.separate = true + case key == "help": // deprecated + spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } + case key == "subcommand": + // decide on a name for the subcommand + cmdname := value + if cmdname == "" { + cmdname = strings.ToLower(field.Name) + } + + // parse the subcommand recursively + subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) + if err != nil { + errs = append(errs, err.Error()) + return false + } - help, exists := field.Tag.Lookup("help") - if exists { - spec.help = help + cmd.subcommands = append(cmd.subcommands, subcmd) + isSubcommand = true + default: + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false + } } + } + + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + if !isSubcommand { + cmd.specs = append(cmd.specs, &spec) - // Check whether this field is supported. It's good to do this here rather than - // wait until ParseValue because it means that a program with invalid argument - // fields will always fail regardless of whether the arguments it received - // exercised those fields. var parseable bool parseable, spec.boolean, spec.multiple = canParse(field.Type) if !parseable { @@ -169,67 +313,28 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { t.Name(), field.Name, field.Type.String())) return false } + } - // Look at the tag - if tag != "" { - for _, key := range strings.Split(tag, ",") { - key = strings.TrimLeft(key, " ") - var value string - if pos := strings.Index(key, ":"); pos != -1 { - value = key[pos+1:] - key = key[:pos] - } - - switch { - case strings.HasPrefix(key, "---"): - errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) - case strings.HasPrefix(key, "--"): - spec.long = key[2:] - case strings.HasPrefix(key, "-"): - if len(key) != 2 { - errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", - t.Name(), field.Name)) - return false - } - spec.short = key[1:] - case key == "required": - spec.required = true - case key == "positional": - spec.positional = true - case key == "separate": - spec.separate = true - case key == "help": // deprecated - spec.help = value - case key == "env": - // Use override name if provided - if value != "" { - spec.env = value - } else { - spec.env = strings.ToUpper(field.Name) - } - default: - errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) - return false - } - } - } - p.specs = append(p.specs, &spec) + // if this was an embedded field then we already returned true up above + return false + }) - // if this was an embedded field then we already returned true up above - return false - }) + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) + } - if len(errs) > 0 { - return nil, errors.New(strings.Join(errs, "\n")) + // check that we don't have both positionals and subcommands + var hasPositional bool + for _, spec := range cmd.specs { + if spec.positional { + hasPositional = true } } - if p.config.Program == "" { - p.config.Program = "program" - if len(os.Args) > 0 { - p.config.Program = filepath.Base(os.Args[0]) - } + if hasPositional && len(cmd.subcommands) > 0 { + return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t) } - return &p, nil + + return &cmd, nil } // Parse processes the given command line option, storing the results in the field @@ -249,30 +354,11 @@ func (p *Parser) Parse(args []string) error { } // Process all command line arguments - return process(p.specs, args) + return p.process(args) } -// process goes through arguments one-by-one, parses them, and assigns the result to -// the underlying struct field -func process(specs []*spec, args []string) error { - // track the options we have seen - wasPresent := make(map[*spec]bool) - - // construct a map from --option to spec - optionMap := make(map[string]*spec) - for _, spec := range specs { - if spec.positional { - continue - } - if spec.long != "" { - optionMap[spec.long] = spec - } - if spec.short != "" { - optionMap[spec.short] = spec - } - } - - // deal with environment vars +// process environment vars for the given arguments +func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { for _, spec := range specs { if spec.env == "" { continue @@ -294,7 +380,7 @@ func process(specs []*spec, args []string) error { err, ) } - if err = setSlice(spec.dest, values, !spec.separate); err != nil { + if err = setSlice(p.writable(spec.dest), values, !spec.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", spec.env, @@ -302,13 +388,35 @@ func process(specs []*spec, args []string) error { ) } } else { - if err := scalar.ParseValue(spec.dest, value); err != nil { + if err := scalar.ParseValue(p.writable(spec.dest), value); err != nil { return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) } } wasPresent[spec] = true } + return nil +} + +// process goes through arguments one-by-one, parses them, and assigns the result to +// the underlying struct field +func (p *Parser) process(args []string) error { + // track the options we have seen + wasPresent := make(map[*spec]bool) + + // union of specs for the chain of subcommands encountered so far + curCmd := p.cmd + + // make a copy of the specs because we will add to this list each time we expand a subcommand + specs := make([]*spec, len(curCmd.specs)) + copy(specs, curCmd.specs) + + // deal with environment vars + err := p.captureEnvVars(specs, wasPresent) + if err != nil { + return err + } + // process each string from the command line var allpositional bool var positionals []string @@ -322,7 +430,32 @@ func process(specs []*spec, args []string) error { } if !isFlag(arg) || allpositional { - positionals = append(positionals, arg) + // each subcommand can have either subcommands or positionals, but not both + if len(curCmd.subcommands) == 0 { + positionals = append(positionals, arg) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(curCmd.subcommands, arg) + if subcmd == nil { + return fmt.Errorf("invalid subcommand: %s", arg) + } + + // instantiate the field to point to a new struct + v := p.writable(subcmd.dest) + v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers + + // add the new options to the set of allowed options + specs = append(specs, subcmd.specs...) + + // capture environment vars for these new options + err := p.captureEnvVars(subcmd.specs, wasPresent) + if err != nil { + return err + } + + curCmd = subcmd continue } @@ -334,9 +467,10 @@ func process(specs []*spec, args []string) error { opt = opt[:pos] } - // lookup the spec for this option - spec, ok := optionMap[opt] - if !ok { + // lookup the spec for this option (note that the "specs" slice changes as + // we expand subcommands so it is better not to use a map) + spec := findOption(specs, opt) + if spec == nil { return fmt.Errorf("unknown argument %s", arg) } wasPresent[spec] = true @@ -355,7 +489,7 @@ func process(specs []*spec, args []string) error { } else { values = append(values, value) } - err := setSlice(spec.dest, values, !spec.separate) + err := setSlice(p.writable(spec.dest), values, !spec.separate) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -373,14 +507,14 @@ func process(specs []*spec, args []string) error { if i+1 == len(args) { return fmt.Errorf("missing value for %s", arg) } - if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) { + if !nextIsNumeric(spec.typ, args[i+1]) && isFlag(args[i+1]) { return fmt.Errorf("missing value for %s", arg) } value = args[i+1] i++ } - err := scalar.ParseValue(spec.dest, value) + err := scalar.ParseValue(p.writable(spec.dest), value) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -396,13 +530,13 @@ func process(specs []*spec, args []string) error { } wasPresent[spec] = true if spec.multiple { - err := setSlice(spec.dest, positionals, true) + err := setSlice(p.writable(spec.dest), positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = nil } else { - err := scalar.ParseValue(spec.dest, positionals[0]) + err := scalar.ParseValue(p.writable(spec.dest), positionals[0]) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } @@ -445,6 +579,55 @@ func isFlag(s string) bool { return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" } +// readable returns a reflect.Value corresponding to the current value for the +// given +func (p *Parser) readable(dest path) reflect.Value { + v := p.roots[dest.root] + for _, field := range dest.fields { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + + v = v.FieldByName(field) + if !v.IsValid() { + // it is appropriate to panic here because this can only happen due to + // an internal bug in this library (since we construct the path ourselves + // by reflecting on the same struct) + panic(fmt.Errorf("error resolving path %v: %v has no field named %v", + dest.fields, v.Type(), field)) + } + } + return v +} + +// writable trav.patherses the destination struct to find the destination to +// which the value of the given spec should be written. It fills in null +// structs with pointers to the zero value for that struct. +func (p *Parser) writable(dest path) reflect.Value { + v := p.roots[dest.root] + for _, field := range dest.fields { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + v = v.FieldByName(field) + if !v.IsValid() { + // it is appropriate to panic here because this can only happen due to + // an internal bug in this library (since we construct the path ourselves + // by reflecting on the same struct) + panic(fmt.Errorf("error resolving path %v: %v has no field named %v", + dest.fields, v.Type(), field)) + } + } + return v +} + // parse a value as the appropriate type and store it in the struct func setSlice(dest reflect.Value, values []string, trunc bool) error { if !dest.CanSet() { @@ -529,3 +712,26 @@ func isBoolean(t reflect.Type) bool { return false } } + +// findOption finds an option from its name, or returns null if no spec is found +func findOption(specs []*spec, name string) *spec { + for _, spec := range specs { + if spec.positional { + continue + } + if spec.long == name || spec.short == name { + return spec + } + } + return nil +} + +// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found +func findSubcommand(cmds []*command, name string) *command { + for _, cmd := range cmds { + if cmd.name == name { + return cmd + } + } + return nil +} diff --git a/parse_test.go b/parse_test.go index 9aad2e3..94cf21a 100644 --- a/parse_test.go +++ b/parse_test.go @@ -462,11 +462,10 @@ func TestPanicOnNonPointer(t *testing.T) { }) } -func TestPanicOnNonStruct(t *testing.T) { +func TestErrorOnNonStruct(t *testing.T) { var args string - assert.Panics(t, func() { - _ = parse("", &args) - }) + err := parse("", &args) + assert.Error(t, err) } func TestUnsupportedType(t *testing.T) { diff --git a/subcommand_test.go b/subcommand_test.go new file mode 100644 index 0000000..25689a4 --- /dev/null +++ b/subcommand_test.go @@ -0,0 +1,170 @@ +package arg + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// This file contains tests for parse.go but I decided to put them here +// since that file is getting large + +func TestSubcommandNotAPointer(t *testing.T) { + var args struct { + A string `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestSubcommandNotAPointerToStruct(t *testing.T) { + var args struct { + A struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestPositionalAndSubcommandNotAllowed(t *testing.T) { + var args struct { + A string `arg:"positional"` + B struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestMinimalSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + err := parse("list", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) +} + +func TestNamedSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand:ls"` + } + err := parse("ls", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) +} + +func TestEmptySubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.List) +} + +func TestTwoSubcommands(t *testing.T) { + type getCmd struct { + } + type listCmd struct { + } + var args struct { + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + err := parse("list", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) +} + +func TestSubcommandsWithOptions(t *testing.T) { + type getCmd struct { + Name string + } + type listCmd struct { + Limit int + } + type cmd struct { + Verbose bool + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("list", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + } + + { + var args cmd + err := parse("list --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + } + + { + var args cmd + err := parse("list --limit 3 --verbose", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("list --verbose --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("--verbose list --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("get", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Nil(t, args.List) + } + + { + var args cmd + err := parse("get --name test", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Nil(t, args.List) + assert.Equal(t, args.Get.Name, "test") + } +} + +func TestNestedSubcommands(t *testing.T) { + // tree of subcommands +} + +func TestSubcommandsWithPositionals(t *testing.T) { + // subcommands with positional arguments +} @@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) { // WriteUsage writes usage information to the given writer func (p *Parser) WriteUsage(w io.Writer) { var positionals, options []*spec - for _, spec := range p.specs { + for _, spec := range p.cmd.specs { if spec.positional { positionals = append(positionals, spec) } else { @@ -34,7 +34,7 @@ func (p *Parser) WriteUsage(w io.Writer) { fmt.Fprintln(w, p.version) } - fmt.Fprintf(w, "Usage: %s", p.config.Program) + fmt.Fprintf(w, "Usage: %s", p.cmd.name) // write the option component of the usage message for _, spec := range options { @@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) { // WriteHelp writes the usage string followed by the full help string for each option func (p *Parser) WriteHelp(w io.Writer) { var positionals, options []*spec - for _, spec := range p.specs { + for _, spec := range p.cmd.specs { if spec.positional { positionals = append(positionals, spec) } else { @@ -106,17 +106,26 @@ func (p *Parser) WriteHelp(w io.Writer) { // write the list of options fmt.Fprint(w, "\nOptions:\n") for _, spec := range options { - printOption(w, spec) + p.printOption(w, spec) } // write the list of built in options - printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"}) + p.printOption(w, &spec{ + boolean: true, + long: "help", + short: "h", + help: "display this help and exit", + }) if p.version != "" { - printOption(w, &spec{boolean: true, long: "version", help: "display version and exit"}) + p.printOption(w, &spec{ + boolean: true, + long: "version", + help: "display version and exit", + }) } } -func printOption(w io.Writer, spec *spec) { +func (p *Parser) printOption(w io.Writer, spec *spec) { left := " " + synopsis(spec, "--"+spec.long) if spec.short != "" { left += ", " + synopsis(spec, "-"+spec.short) @@ -131,7 +140,10 @@ func printOption(w io.Writer, spec *spec) { fmt.Fprint(w, spec.help) } // If spec.dest is not the zero value then a default value has been added. - v := spec.dest + var v reflect.Value + if len(spec.dest.fields) > 0 { + v = p.readable(spec.dest) + } if v.IsValid() { z := reflect.Zero(v.Type()) if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() { |
