summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.travis.yml2
-rw-r--r--README.md53
-rw-r--r--example_test.go166
-rw-r--r--parse.go518
-rw-r--r--parse_test.go76
-rw-r--r--reflect.go62
-rw-r--r--reflect_test.go55
-rw-r--r--subcommand.go37
-rw-r--r--subcommand_test.go355
-rw-r--r--usage.go129
10 files changed, 1220 insertions, 233 deletions
diff --git a/.travis.yml b/.travis.yml
index f953225..d2a00fd 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,7 @@
language: go
go:
- "1.10"
- - "1.11"
+ - "1.12"
- tip
env:
- GO111MODULE=on # will only be used in go 1.11
diff --git a/README.md b/README.md
index c772994..c516c51 100644
--- a/README.md
+++ b/README.md
@@ -353,6 +353,59 @@ Options:
--help, -h display this help and exit
```
+### Subcommands
+
+*Introduced in `v1.1.0`*
+
+Subcommands are commonly used in tools that wish to group multiple functions into a single program. An example is the `git` tool:
+```shell
+$ git checkout [arguments specific to checking out code]
+$ git commit [arguments specific to committing]
+$ git push [arguments specific to pushing]
+```
+
+The strings "checkout", "commit", and "push" are different from simple positional arguments because the options available to the user change depending on which subcommand they choose.
+
+This can be implemented with `go-arg` as follows:
+
+```go
+type CheckoutCmd struct {
+ Branch string `arg:"positional"`
+ Track bool `arg:"-t"`
+}
+type CommitCmd struct {
+ All bool `arg:"-a"`
+ Message string `arg:"-m"`
+}
+type PushCmd struct {
+ Remote string `arg:"positional"`
+ Branch string `arg:"positional"`
+ SetUpstream bool `arg:"-u"`
+}
+var args struct {
+ Checkout *CheckoutCmd `arg:"subcommand:checkout"`
+ Commit *CommitCmd `arg:"subcommand:commit"`
+ Push *PushCmd `arg:"subcommand:push"`
+ Quiet bool `arg:"-q"` // this flag is global to all subcommands
+}
+
+arg.MustParse(&args)
+
+switch {
+case args.Checkout != nil:
+ fmt.Printf("checkout requested for branch %s\n", args.Checkout.Branch)
+case args.Commit != nil:
+ fmt.Printf("commit requested with message \"%s\"\n", args.Commit.Message)
+case args.Push != nil:
+ fmt.Printf("push requested from %s to %s\n", args.Push.Branch, args.Push.Remote)
+}
+```
+
+Some additional rules apply when working with subcommands:
+* The `subcommand` tag can only be used with fields that are pointers to structs
+* Any struct that contains a subcommand must not contain any positionals
+
+
### API Documentation
https://godoc.org/github.com/alexflint/go-arg
diff --git a/example_test.go b/example_test.go
index 72807a7..2188253 100644
--- a/example_test.go
+++ b/example_test.go
@@ -104,7 +104,7 @@ func Example_multipleMixed() {
}
// This example shows the usage string generated by go-arg
-func Example_usageString() {
+func Example_helpText() {
// These are the args you would pass in on the command line
os.Args = split("./example --help")
@@ -135,3 +135,167 @@ func Example_usageString() {
// optimization level
// --help, -h display this help and exit
}
+
+// This example shows the usage string generated by go-arg when using subcommands
+func Example_helpTextWithSubcommand() {
+ // These are the args you would pass in on the command line
+ os.Args = split("./example --help")
+
+ type getCmd struct {
+ Item string `arg:"positional" help:"item to fetch"`
+ }
+
+ type listCmd struct {
+ Format string `help:"output format"`
+ Limit int
+ }
+
+ var args struct {
+ Verbose bool
+ Get *getCmd `arg:"subcommand" help:"fetch an item and print it"`
+ List *listCmd `arg:"subcommand" help:"list available items"`
+ }
+
+ // This is only necessary when running inside golang's runnable example harness
+ osExit = func(int) {}
+
+ MustParse(&args)
+
+ // output:
+ // Usage: example [--verbose]
+ //
+ // Options:
+ // --verbose
+ // --help, -h display this help and exit
+ //
+ // Commands:
+ // get fetch an item and print it
+ // list list available items
+}
+
+// This example shows the usage string generated by go-arg when using subcommands
+func Example_helpTextForSubcommand() {
+ // These are the args you would pass in on the command line
+ os.Args = split("./example get --help")
+
+ type getCmd struct {
+ Item string `arg:"positional" help:"item to fetch"`
+ }
+
+ type listCmd struct {
+ Format string `help:"output format"`
+ Limit int
+ }
+
+ var args struct {
+ Verbose bool
+ Get *getCmd `arg:"subcommand" help:"fetch an item and print it"`
+ List *listCmd `arg:"subcommand" help:"list available items"`
+ }
+
+ // This is only necessary when running inside golang's runnable example harness
+ osExit = func(int) {}
+
+ MustParse(&args)
+
+ // output:
+ // Usage: example get ITEM
+ //
+ // Positional arguments:
+ // ITEM item to fetch
+ //
+ // Options:
+ // --help, -h display this help and exit
+}
+
+// This example shows the error string generated by go-arg when an invalid option is provided
+func Example_errorText() {
+ // These are the args you would pass in on the command line
+ os.Args = split("./example --optimize INVALID")
+
+ var args struct {
+ Input string `arg:"positional"`
+ Output []string `arg:"positional"`
+ Verbose bool `arg:"-v" help:"verbosity level"`
+ Dataset string `help:"dataset to use"`
+ Optimize int `arg:"-O,help:optimization level"`
+ }
+
+ // This is only necessary when running inside golang's runnable example harness
+ osExit = func(int) {}
+ stderr = os.Stdout
+
+ MustParse(&args)
+
+ // output:
+ // Usage: example [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] INPUT [OUTPUT [OUTPUT ...]]
+ // error: error processing --optimize: strconv.ParseInt: parsing "INVALID": invalid syntax
+}
+
+// This example shows the error string generated by go-arg when an invalid option is provided
+func Example_errorTextForSubcommand() {
+ // These are the args you would pass in on the command line
+ os.Args = split("./example get --count INVALID")
+
+ type getCmd struct {
+ Count int
+ }
+
+ var args struct {
+ Get *getCmd `arg:"subcommand"`
+ }
+
+ // This is only necessary when running inside golang's runnable example harness
+ osExit = func(int) {}
+ stderr = os.Stdout
+
+ MustParse(&args)
+
+ // output:
+ // Usage: example get [--count COUNT]
+ // error: error processing --count: strconv.ParseInt: parsing "INVALID": invalid syntax
+}
+
+// This example demonstrates use of subcommands
+func Example_subcommand() {
+ // These are the args you would pass in on the command line
+ os.Args = split("./example commit -a -m what-this-commit-is-about")
+
+ type CheckoutCmd struct {
+ Branch string `arg:"positional"`
+ Track bool `arg:"-t"`
+ }
+ type CommitCmd struct {
+ All bool `arg:"-a"`
+ Message string `arg:"-m"`
+ }
+ type PushCmd struct {
+ Remote string `arg:"positional"`
+ Branch string `arg:"positional"`
+ SetUpstream bool `arg:"-u"`
+ }
+ var args struct {
+ Checkout *CheckoutCmd `arg:"subcommand:checkout"`
+ Commit *CommitCmd `arg:"subcommand:commit"`
+ Push *PushCmd `arg:"subcommand:push"`
+ Quiet bool `arg:"-q"` // this flag is global to all subcommands
+ }
+
+ // This is only necessary when running inside golang's runnable example harness
+ osExit = func(int) {}
+ stderr = os.Stdout
+
+ MustParse(&args)
+
+ switch {
+ case args.Checkout != nil:
+ fmt.Printf("checkout requested for branch %s\n", args.Checkout.Branch)
+ case args.Commit != nil:
+ fmt.Printf("commit requested with message \"%s\"\n", args.Commit.Message)
+ case args.Push != nil:
+ fmt.Printf("push requested from %s to %s\n", args.Push.Branch, args.Push.Remote)
+ }
+
+ // output:
+ // commit requested with message "what-this-commit-is-about"
+}
diff --git a/parse.go b/parse.go
index e23731c..2a6b7c7 100644
--- a/parse.go
+++ b/parse.go
@@ -1,7 +1,6 @@
package arg
import (
- "encoding"
"encoding/csv"
"errors"
"fmt"
@@ -16,9 +15,36 @@ import (
// to enable monkey-patching during tests
var osExit = os.Exit
+// path represents a sequence of steps to find the output location for an
+// argument or subcommand in the final destination struct
+type path struct {
+ root int // index of the destination struct
+ fields []string // sequence of struct field names to traverse
+}
+
+// String gets a string representation of the given path
+func (p path) String() string {
+ if len(p.fields) == 0 {
+ return "args"
+ }
+ return "args." + strings.Join(p.fields, ".")
+}
+
+// Child gets a new path representing a child of this path.
+func (p path) Child(child string) path {
+ // copy the entire slice of fields to avoid possible slice overwrite
+ subfields := make([]string, len(p.fields)+1)
+ copy(subfields, append(p.fields, child))
+ return path{
+ root: p.root,
+ fields: subfields,
+ }
+}
+
// spec represents a command line option
type spec struct {
- dest reflect.Value
+ dest path
+ typ reflect.Type
long string
short string
multiple bool
@@ -30,6 +56,16 @@ type spec struct {
boolean bool
}
+// command represents a named subcommand, or the top-level command
+type command struct {
+ name string
+ help string
+ dest path
+ specs []*spec
+ subcommands []*command
+ parent *command
+}
+
// ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user")
@@ -42,18 +78,19 @@ func MustParse(dest ...interface{}) *Parser {
if err != nil {
fmt.Println(err)
osExit(-1)
+ return nil // just in case osExit was monkey-patched
}
err = p.Parse(flags())
switch {
case err == ErrHelp:
- p.WriteHelp(os.Stdout)
+ p.writeHelpForCommand(os.Stdout, p.lastCmd)
osExit(0)
case err == ErrVersion:
fmt.Println(p.version)
osExit(0)
case err != nil:
- p.Fail(err.Error())
+ p.failWithCommand(err.Error(), p.lastCmd)
}
return p
@@ -83,10 +120,14 @@ type Config struct {
// Parser represents a set of command line options with destination values
type Parser struct {
- specs []*spec
+ cmd *command
+ roots []reflect.Value
config Config
version string
description string
+
+ // the following fields change curing processing of command line arguments
+ lastCmd *command
}
// Versioned is the interface that the destination struct should implement to
@@ -106,66 +147,180 @@ type Described interface {
}
// walkFields calls a function for each field of a struct, recursively expanding struct fields.
-func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) {
- t := v.Type()
+func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) {
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
- val := v.Field(i)
- expand := visit(field, val, t)
+ expand := visit(field, t)
if expand && field.Type.Kind() == reflect.Struct {
- walkFields(val, visit)
+ walkFields(field.Type, visit)
}
}
}
// NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) {
+ // first pick a name for the command for use in the usage text
+ var name string
+ switch {
+ case config.Program != "":
+ name = config.Program
+ case len(os.Args) > 0:
+ name = filepath.Base(os.Args[0])
+ default:
+ name = "program"
+ }
+
+ // construct a parser
p := Parser{
+ cmd: &command{name: name},
config: config,
}
+
+ // make a list of roots
for _, dest := range dests {
+ p.roots = append(p.roots, reflect.ValueOf(dest))
+ }
+
+ // process each of the destination values
+ for i, dest := range dests {
+ t := reflect.TypeOf(dest)
+ if t.Kind() != reflect.Ptr {
+ panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
+ }
+
+ cmd, err := cmdFromStruct(name, path{root: i}, t)
+ if err != nil {
+ return nil, err
+ }
+ p.cmd.specs = append(p.cmd.specs, cmd.specs...)
+ p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
+
if dest, ok := dest.(Versioned); ok {
p.version = dest.Version()
}
if dest, ok := dest.(Described); ok {
p.description = dest.Description()
}
- v := reflect.ValueOf(dest)
- if v.Kind() != reflect.Ptr {
- panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type()))
+ }
+
+ return &p, nil
+}
+
+func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
+ // commands can only be created from pointers to structs
+ if t.Kind() != reflect.Ptr {
+ return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s",
+ dest, t.Kind())
+ }
+
+ t = t.Elem()
+ if t.Kind() != reflect.Struct {
+ return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s",
+ dest, t.Kind())
+ }
+
+ cmd := command{
+ name: name,
+ dest: dest,
+ }
+
+ var errs []string
+ walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
+ // Check for the ignore switch in the tag
+ tag := field.Tag.Get("arg")
+ if tag == "-" {
+ return false
}
- v = v.Elem()
- if v.Kind() != reflect.Struct {
- panic(fmt.Sprintf("%T is not a struct pointer", dest))
+
+ // If this is an embedded struct then recurse into its fields
+ if field.Anonymous && field.Type.Kind() == reflect.Struct {
+ return true
}
- var errs []string
- walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool {
- // Check for the ignore switch in the tag
- tag := field.Tag.Get("arg")
- if tag == "-" {
- return false
- }
+ // duplicate the entire path to avoid slice overwrites
+ subdest := dest.Child(field.Name)
+ spec := spec{
+ dest: subdest,
+ long: strings.ToLower(field.Name),
+ typ: field.Type,
+ }
- // If this is an embedded struct then recurse into its fields
- if field.Anonymous && field.Type.Kind() == reflect.Struct {
- return true
- }
+ help, exists := field.Tag.Lookup("help")
+ if exists {
+ spec.help = help
+ }
- spec := spec{
- long: strings.ToLower(field.Name),
- dest: val,
- }
+ // Look at the tag
+ var isSubcommand bool // tracks whether this field is a subcommand
+ if tag != "" {
+ for _, key := range strings.Split(tag, ",") {
+ key = strings.TrimLeft(key, " ")
+ var value string
+ if pos := strings.Index(key, ":"); pos != -1 {
+ value = key[pos+1:]
+ key = key[:pos]
+ }
- help, exists := field.Tag.Lookup("help")
- if exists {
- spec.help = help
+ switch {
+ case strings.HasPrefix(key, "---"):
+ errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name))
+ case strings.HasPrefix(key, "--"):
+ spec.long = key[2:]
+ case strings.HasPrefix(key, "-"):
+ if len(key) != 2 {
+ errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
+ t.Name(), field.Name))
+ return false
+ }
+ spec.short = key[1:]
+ case key == "required":
+ spec.required = true
+ case key == "positional":
+ spec.positional = true
+ case key == "separate":
+ spec.separate = true
+ case key == "help": // deprecated
+ spec.help = value
+ case key == "env":
+ // Use override name if provided
+ if value != "" {
+ spec.env = value
+ } else {
+ spec.env = strings.ToUpper(field.Name)
+ }
+ case key == "subcommand":
+ // decide on a name for the subcommand
+ cmdname := value
+ if cmdname == "" {
+ cmdname = strings.ToLower(field.Name)
+ }
+
+ // parse the subcommand recursively
+ subcmd, err := cmdFromStruct(cmdname, subdest, field.Type)
+ if err != nil {
+ errs = append(errs, err.Error())
+ return false
+ }
+
+ subcmd.parent = &cmd
+ subcmd.help = field.Tag.Get("help")
+
+ cmd.subcommands = append(cmd.subcommands, subcmd)
+ isSubcommand = true
+ default:
+ errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
+ return false
+ }
}
+ }
+
+ // Check whether this field is supported. It's good to do this here rather than
+ // wait until ParseValue because it means that a program with invalid argument
+ // fields will always fail regardless of whether the arguments it received
+ // exercised those fields.
+ if !isSubcommand {
+ cmd.specs = append(cmd.specs, &spec)
- // Check whether this field is supported. It's good to do this here rather than
- // wait until ParseValue because it means that a program with invalid argument
- // fields will always fail regardless of whether the arguments it received
- // exercised those fields.
var parseable bool
parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable {
@@ -173,110 +328,50 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
t.Name(), field.Name, field.Type.String()))
return false
}
+ }
- // Look at the tag
- if tag != "" {
- for _, key := range strings.Split(tag, ",") {
- key = strings.TrimLeft(key, " ")
- var value string
- if pos := strings.Index(key, ":"); pos != -1 {
- value = key[pos+1:]
- key = key[:pos]
- }
-
- switch {
- case strings.HasPrefix(key, "---"):
- errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name))
- case strings.HasPrefix(key, "--"):
- spec.long = key[2:]
- case strings.HasPrefix(key, "-"):
- if len(key) != 2 {
- errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
- t.Name(), field.Name))
- return false
- }
- spec.short = key[1:]
- case key == "required":
- spec.required = true
- case key == "positional":
- spec.positional = true
- case key == "separate":
- spec.separate = true
- case key == "help": // deprecated
- spec.help = value
- case key == "env":
- // Use override name if provided
- if value != "" {
- spec.env = value
- } else {
- spec.env = strings.ToUpper(field.Name)
- }
- default:
- errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
- return false
- }
- }
- }
- p.specs = append(p.specs, &spec)
+ // if this was an embedded field then we already returned true up above
+ return false
+ })
- // if this was an embedded field then we already returned true up above
- return false
- })
+ if len(errs) > 0 {
+ return nil, errors.New(strings.Join(errs, "\n"))
+ }
- if len(errs) > 0 {
- return nil, errors.New(strings.Join(errs, "\n"))
+ // check that we don't have both positionals and subcommands
+ var hasPositional bool
+ for _, spec := range cmd.specs {
+ if spec.positional {
+ hasPositional = true
}
}
- if p.config.Program == "" {
- p.config.Program = "program"
- if len(os.Args) > 0 {
- p.config.Program = filepath.Base(os.Args[0])
- }
+ if hasPositional && len(cmd.subcommands) > 0 {
+ return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest)
}
- return &p, nil
+
+ return &cmd, nil
}
// Parse processes the given command line option, storing the results in the field
// of the structs from which NewParser was constructed
func (p *Parser) Parse(args []string) error {
- // If -h or --help were specified then print usage
- for _, arg := range args {
- if arg == "-h" || arg == "--help" {
- return ErrHelp
- }
- if arg == "--version" {
- return ErrVersion
- }
- if arg == "--" {
- break
+ err := p.process(args)
+ if err != nil {
+ // If -h or --help were specified then make sure help text supercedes other errors
+ for _, arg := range args {
+ if arg == "-h" || arg == "--help" {
+ return ErrHelp
+ }
+ if arg == "--" {
+ break
+ }
}
}
-
- // Process all command line arguments
- return process(p.specs, args)
+ return err
}
-// process goes through arguments one-by-one, parses them, and assigns the result to
-// the underlying struct field
-func process(specs []*spec, args []string) error {
- // track the options we have seen
- wasPresent := make(map[*spec]bool)
-
- // construct a map from --option to spec
- optionMap := make(map[string]*spec)
- for _, spec := range specs {
- if spec.positional {
- continue
- }
- if spec.long != "" {
- optionMap[spec.long] = spec
- }
- if spec.short != "" {
- optionMap[spec.short] = spec
- }
- }
-
- // deal with environment vars
+// process environment vars for the given arguments
+func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error {
for _, spec := range specs {
if spec.env == "" {
continue
@@ -298,7 +393,7 @@ func process(specs []*spec, args []string) error {
err,
)
}
- if err = setSlice(spec.dest, values, !spec.separate); err != nil {
+ if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil {
return fmt.Errorf(
"error processing environment variable %s with multiple values: %v",
spec.env,
@@ -306,13 +401,36 @@ func process(specs []*spec, args []string) error {
)
}
} else {
- if err := scalar.ParseValue(spec.dest, value); err != nil {
+ if err := scalar.ParseValue(p.val(spec.dest), value); err != nil {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
}
}
wasPresent[spec] = true
}
+ return nil
+}
+
+// process goes through arguments one-by-one, parses them, and assigns the result to
+// the underlying struct field
+func (p *Parser) process(args []string) error {
+ // track the options we have seen
+ wasPresent := make(map[*spec]bool)
+
+ // union of specs for the chain of subcommands encountered so far
+ curCmd := p.cmd
+ p.lastCmd = curCmd
+
+ // make a copy of the specs because we will add to this list each time we expand a subcommand
+ specs := make([]*spec, len(curCmd.specs))
+ copy(specs, curCmd.specs)
+
+ // deal with environment vars
+ err := p.captureEnvVars(specs, wasPresent)
+ if err != nil {
+ return err
+ }
+
// process each string from the command line
var allpositional bool
var positionals []string
@@ -326,10 +444,44 @@ func process(specs []*spec, args []string) error {
}
if !isFlag(arg) || allpositional {
- positionals = append(positionals, arg)
+ // each subcommand can have either subcommands or positionals, but not both
+ if len(curCmd.subcommands) == 0 {
+ positionals = append(positionals, arg)
+ continue
+ }
+
+ // if we have a subcommand then make sure it is valid for the current context
+ subcmd := findSubcommand(curCmd.subcommands, arg)
+ if subcmd == nil {
+ return fmt.Errorf("invalid subcommand: %s", arg)
+ }
+
+ // instantiate the field to point to a new struct
+ v := p.val(subcmd.dest)
+ v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
+
+ // add the new options to the set of allowed options
+ specs = append(specs, subcmd.specs...)
+
+ // capture environment vars for these new options
+ err := p.captureEnvVars(subcmd.specs, wasPresent)
+ if err != nil {
+ return err
+ }
+
+ curCmd = subcmd
+ p.lastCmd = curCmd
continue
}
+ // check for special --help and --version flags
+ switch arg {
+ case "-h", "--help":
+ return ErrHelp
+ case "--version":
+ return ErrVersion
+ }
+
// check for an equals sign, as in "--foo=bar"
var value string
opt := strings.TrimLeft(arg, "-")
@@ -338,9 +490,10 @@ func process(specs []*spec, args []string) error {
opt = opt[:pos]
}
- // lookup the spec for this option
- spec, ok := optionMap[opt]
- if !ok {
+ // lookup the spec for this option (note that the "specs" slice changes as
+ // we expand subcommands so it is better not to use a map)
+ spec := findOption(specs, opt)
+ if spec == nil {
return fmt.Errorf("unknown argument %s", arg)
}
wasPresent[spec] = true
@@ -359,7 +512,7 @@ func process(specs []*spec, args []string) error {
} else {
values = append(values, value)
}
- err := setSlice(spec.dest, values, !spec.separate)
+ err := setSlice(p.val(spec.dest), values, !spec.separate)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@@ -377,14 +530,14 @@ func process(specs []*spec, args []string) error {
if i+1 == len(args) {
return fmt.Errorf("missing value for %s", arg)
}
- if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) {
+ if !nextIsNumeric(spec.typ, args[i+1]) && isFlag(args[i+1]) {
return fmt.Errorf("missing value for %s", arg)
}
value = args[i+1]
i++
}
- err := scalar.ParseValue(spec.dest, value)
+ err := scalar.ParseValue(p.val(spec.dest), value)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@@ -400,13 +553,13 @@ func process(specs []*spec, args []string) error {
}
wasPresent[spec] = true
if spec.multiple {
- err := setSlice(spec.dest, positionals, true)
+ err := setSlice(p.val(spec.dest), positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
positionals = nil
} else {
- err := scalar.ParseValue(spec.dest, positionals[0])
+ err := scalar.ParseValue(p.val(spec.dest), positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
@@ -449,6 +602,30 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
}
+// val returns a reflect.Value corresponding to the current value for the
+// given path
+func (p *Parser) val(dest path) reflect.Value {
+ v := p.roots[dest.root]
+ for _, field := range dest.fields {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return reflect.Value{}
+ }
+ v = v.Elem()
+ }
+
+ v = v.FieldByName(field)
+ if !v.IsValid() {
+ // it is appropriate to panic here because this can only happen due to
+ // an internal bug in this library (since we construct the path ourselves
+ // by reflecting on the same struct)
+ panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
+ dest.fields, v.Type(), field))
+ }
+ }
+ return v
+}
+
// parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() {
@@ -480,56 +657,25 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error {
return nil
}
-// canParse returns true if the type can be parsed from a string
-func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
- }
-
- // Look inside pointer types
- if t.Kind() == reflect.Ptr {
- t = t.Elem()
- }
- // Look inside slice types
- if t.Kind() == reflect.Slice {
- multiple = true
- t = t.Elem()
- }
-
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
- }
-
- // Look inside pointer types (again, in case of []*Type)
- if t.Kind() == reflect.Ptr {
- t = t.Elem()
- }
-
- parseable = scalar.CanParse(t)
- boolean = isBoolean(t)
- if parseable {
- return
+// findOption finds an option from its name, or returns null if no spec is found
+func findOption(specs []*spec, name string) *spec {
+ for _, spec := range specs {
+ if spec.positional {
+ continue
+ }
+ if spec.long == name || spec.short == name {
+ return spec
+ }
}
-
- return false, false, false
+ return nil
}
-var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
-
-// isBoolean returns true if the type can be parsed from a single string
-func isBoolean(t reflect.Type) bool {
- switch {
- case t.Implements(textUnmarshalerType):
- return false
- case t.Kind() == reflect.Bool:
- return true
- case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
- return true
- default:
- return false
+// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found
+func findSubcommand(cmds []*command, name string) *command {
+ for _, cmd := range cmds {
+ if cmd.name == name {
+ return cmd
+ }
}
+ return nil
}
diff --git a/parse_test.go b/parse_test.go
index 9aad2e3..882564e 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -19,15 +19,20 @@ func setenv(t *testing.T, name, val string) {
}
func parse(cmdline string, dest interface{}) error {
+ _, err := pparse(cmdline, dest)
+ return err
+}
+
+func pparse(cmdline string, dest interface{}) (*Parser, error) {
p, err := NewParser(Config{}, dest)
if err != nil {
- return err
+ return nil, err
}
var parts []string
if len(cmdline) > 0 {
parts = strings.Split(cmdline, " ")
}
- return p.Parse(parts)
+ return p, p.Parse(parts)
}
func TestString(t *testing.T) {
@@ -371,7 +376,7 @@ func TestNonsenseKey(t *testing.T) {
assert.Error(t, err)
}
-func TestMissingValue(t *testing.T) {
+func TestMissingValueAtEnd(t *testing.T) {
var args struct {
Foo string
}
@@ -379,6 +384,24 @@ func TestMissingValue(t *testing.T) {
assert.Error(t, err)
}
+func TestMissingValueInMIddle(t *testing.T) {
+ var args struct {
+ Foo string
+ Bar string
+ }
+ err := parse("--foo --bar=abc", &args)
+ assert.Error(t, err)
+}
+
+func TestNegativeValue(t *testing.T) {
+ var args struct {
+ Foo int
+ }
+ err := parse("--foo -123", &args)
+ require.NoError(t, err)
+ assert.Equal(t, -123, args.Foo)
+}
+
func TestInvalidInt(t *testing.T) {
var args struct {
Foo int
@@ -462,11 +485,10 @@ func TestPanicOnNonPointer(t *testing.T) {
})
}
-func TestPanicOnNonStruct(t *testing.T) {
+func TestErrorOnNonStruct(t *testing.T) {
var args string
- assert.Panics(t, func() {
- _ = parse("", &args)
- })
+ err := parse("", &args)
+ assert.Error(t, err)
}
func TestUnsupportedType(t *testing.T) {
@@ -540,6 +562,15 @@ func TestEnvironmentVariable(t *testing.T) {
assert.Equal(t, "bar", args.Foo)
}
+func TestEnvironmentVariableNotPresent(t *testing.T) {
+ var args struct {
+ NotPresent string `arg:"env"`
+ }
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "", args.NotPresent)
+}
+
func TestEnvironmentVariableOverrideName(t *testing.T) {
var args struct {
Foo string `arg:"env:BAZ"`
@@ -584,7 +615,7 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) {
var args struct {
Foo []string `arg:"env"`
}
- setenv(t, "FOO", "bar,\"baz, qux\"")
+ setenv(t, "FOO", `bar,"baz, qux"`)
MustParse(&args)
assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo)
}
@@ -846,6 +877,28 @@ func TestEmbedded(t *testing.T) {
assert.Equal(t, true, args.Z)
}
+func TestEmbeddedPtr(t *testing.T) {
+ // embedded pointer fields are not supported so this should return an error
+ var args struct {
+ *A
+ }
+ err := parse("--x=hello", &args)
+ require.Error(t, err)
+}
+
+func TestEmbeddedPtrIgnored(t *testing.T) {
+ // embedded pointer fields are not normally supported but here
+ // we explicitly exclude it so the non-nil embedded structs
+ // should work as expected
+ var args struct {
+ *A `arg:"-"`
+ B
+ }
+ err := parse("--y=321", &args)
+ require.NoError(t, err)
+ assert.Equal(t, 321, args.Y)
+}
+
func TestEmptyArgs(t *testing.T) {
origArgs := os.Args
@@ -985,3 +1038,10 @@ func TestReuseParser(t *testing.T) {
err = p.Parse([]string{})
assert.Error(t, err)
}
+
+func TestVersion(t *testing.T) {
+ var args struct{}
+ err := parse("--version", &args)
+ assert.Equal(t, ErrVersion, err)
+
+}
diff --git a/reflect.go b/reflect.go
new file mode 100644
index 0000000..e113583
--- /dev/null
+++ b/reflect.go
@@ -0,0 +1,62 @@
+package arg
+
+import (
+ "encoding"
+ "reflect"
+
+ scalar "github.com/alexflint/go-scalar"
+)
+
+var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
+
+// canParse returns true if the type can be parsed from a string
+func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
+ parseable = scalar.CanParse(t)
+ boolean = isBoolean(t)
+ if parseable {
+ return
+ }
+
+ // Look inside pointer types
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ // Look inside slice types
+ if t.Kind() == reflect.Slice {
+ multiple = true
+ t = t.Elem()
+ }
+
+ parseable = scalar.CanParse(t)
+ boolean = isBoolean(t)
+ if parseable {
+ return
+ }
+
+ // Look inside pointer types (again, in case of []*Type)
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+
+ parseable = scalar.CanParse(t)
+ boolean = isBoolean(t)
+ if parseable {
+ return
+ }
+
+ return false, false, false
+}
+
+// isBoolean returns true if the type can be parsed from a single string
+func isBoolean(t reflect.Type) bool {
+ switch {
+ case t.Implements(textUnmarshalerType):
+ return false
+ case t.Kind() == reflect.Bool:
+ return true
+ case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/reflect_test.go b/reflect_test.go
new file mode 100644
index 0000000..47e68b5
--- /dev/null
+++ b/reflect_test.go
@@ -0,0 +1,55 @@
+package arg
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) {
+ p, b, m := canParse(typ)
+ assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p)
+ assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b)
+ assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m)
+}
+
+func TestCanParse(t *testing.T) {
+ var b bool
+ var i int
+ var s string
+ var f float64
+ var bs []bool
+ var is []int
+
+ assertCanParse(t, reflect.TypeOf(b), true, true, false)
+ assertCanParse(t, reflect.TypeOf(i), true, false, false)
+ assertCanParse(t, reflect.TypeOf(s), true, false, false)
+ assertCanParse(t, reflect.TypeOf(f), true, false, false)
+
+ assertCanParse(t, reflect.TypeOf(&b), true, true, false)
+ assertCanParse(t, reflect.TypeOf(&s), true, false, false)
+ assertCanParse(t, reflect.TypeOf(&i), true, false, false)
+ assertCanParse(t, reflect.TypeOf(&f), true, false, false)
+
+ assertCanParse(t, reflect.TypeOf(bs), true, true, true)
+ assertCanParse(t, reflect.TypeOf(&bs), true, true, true)
+
+ assertCanParse(t, reflect.TypeOf(is), true, false, true)
+ assertCanParse(t, reflect.TypeOf(&is), true, false, true)
+}
+
+type implementsTextUnmarshaler struct{}
+
+func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
+ return nil
+}
+
+func TestCanParseTextUnmarshaler(t *testing.T) {
+ var u implementsTextUnmarshaler
+ var su []implementsTextUnmarshaler
+ assertCanParse(t, reflect.TypeOf(u), true, false, false)
+ assertCanParse(t, reflect.TypeOf(&u), true, false, false)
+ assertCanParse(t, reflect.TypeOf(su), true, false, true)
+ assertCanParse(t, reflect.TypeOf(&su), true, false, true)
+}
diff --git a/subcommand.go b/subcommand.go
new file mode 100644
index 0000000..dff732c
--- /dev/null
+++ b/subcommand.go
@@ -0,0 +1,37 @@
+package arg
+
+// Subcommand returns the user struct for the subcommand selected by
+// the command line arguments most recently processed by the parser.
+// The return value is always a pointer to a struct. If no subcommand
+// was specified then it returns the top-level arguments struct. If
+// no command line arguments have been processed by this parser then it
+// returns nil.
+func (p *Parser) Subcommand() interface{} {
+ if p.lastCmd == nil || p.lastCmd.parent == nil {
+ return nil
+ }
+ return p.val(p.lastCmd.dest).Interface()
+}
+
+// SubcommandNames returns the sequence of subcommands specified by the
+// user. If no subcommands were given then it returns an empty slice.
+func (p *Parser) SubcommandNames() []string {
+ if p.lastCmd == nil {
+ return nil
+ }
+
+ // make a list of ancestor commands
+ var ancestors []string
+ cur := p.lastCmd
+ for cur.parent != nil { // we want to exclude the root
+ ancestors = append(ancestors, cur.name)
+ cur = cur.parent
+ }
+
+ // reverse the list
+ out := make([]string, len(ancestors))
+ for i := 0; i < len(ancestors); i++ {
+ out[i] = ancestors[len(ancestors)-i-1]
+ }
+ return out
+}
diff --git a/subcommand_test.go b/subcommand_test.go
new file mode 100644
index 0000000..c34ab01
--- /dev/null
+++ b/subcommand_test.go
@@ -0,0 +1,355 @@
+package arg
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// This file contains tests for parse.go but I decided to put them here
+// since that file is getting large
+
+func TestSubcommandNotAPointer(t *testing.T) {
+ var args struct {
+ A string `arg:"subcommand"`
+ }
+ _, err := NewParser(Config{}, &args)
+ assert.Error(t, err)
+}
+
+func TestSubcommandNotAPointerToStruct(t *testing.T) {
+ var args struct {
+ A struct{} `arg:"subcommand"`
+ }
+ _, err := NewParser(Config{}, &args)
+ assert.Error(t, err)
+}
+
+func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
+ var args struct {
+ A string `arg:"positional"`
+ B *struct{} `arg:"subcommand"`
+ }
+ _, err := NewParser(Config{}, &args)
+ assert.Error(t, err)
+}
+
+func TestMinimalSubcommand(t *testing.T) {
+ type listCmd struct {
+ }
+ var args struct {
+ List *listCmd `arg:"subcommand"`
+ }
+ p, err := pparse("list", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List, p.Subcommand())
+ assert.Equal(t, []string{"list"}, p.SubcommandNames())
+}
+
+func TestNoSuchSubcommand(t *testing.T) {
+ type listCmd struct {
+ }
+ var args struct {
+ List *listCmd `arg:"subcommand"`
+ }
+ _, err := pparse("invalid", &args)
+ assert.Error(t, err)
+}
+
+func TestNamedSubcommand(t *testing.T) {
+ type listCmd struct {
+ }
+ var args struct {
+ List *listCmd `arg:"subcommand:ls"`
+ }
+ p, err := pparse("ls", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List, p.Subcommand())
+ assert.Equal(t, []string{"ls"}, p.SubcommandNames())
+}
+
+func TestEmptySubcommand(t *testing.T) {
+ type listCmd struct {
+ }
+ var args struct {
+ List *listCmd `arg:"subcommand"`
+ }
+ p, err := pparse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.List)
+ assert.Nil(t, p.Subcommand())
+ assert.Empty(t, p.SubcommandNames())
+}
+
+func TestTwoSubcommands(t *testing.T) {
+ type getCmd struct {
+ }
+ type listCmd struct {
+ }
+ var args struct {
+ Get *getCmd `arg:"subcommand"`
+ List *listCmd `arg:"subcommand"`
+ }
+ p, err := pparse("list", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Get)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List, p.Subcommand())
+ assert.Equal(t, []string{"list"}, p.SubcommandNames())
+}
+
+func TestSubcommandsWithOptions(t *testing.T) {
+ type getCmd struct {
+ Name string
+ }
+ type listCmd struct {
+ Limit int
+ }
+ type cmd struct {
+ Verbose bool
+ Get *getCmd `arg:"subcommand"`
+ List *listCmd `arg:"subcommand"`
+ }
+
+ {
+ var args cmd
+ err := parse("list", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Get)
+ assert.NotNil(t, args.List)
+ }
+
+ {
+ var args cmd
+ err := parse("list --limit 3", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Get)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List.Limit, 3)
+ }
+
+ {
+ var args cmd
+ err := parse("list --limit 3 --verbose", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Get)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List.Limit, 3)
+ assert.True(t, args.Verbose)
+ }
+
+ {
+ var args cmd
+ err := parse("list --verbose --limit 3", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Get)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List.Limit, 3)
+ assert.True(t, args.Verbose)
+ }
+
+ {
+ var args cmd
+ err := parse("--verbose list --limit 3", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Get)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, args.List.Limit, 3)
+ assert.True(t, args.Verbose)
+ }
+
+ {
+ var args cmd
+ err := parse("get", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Nil(t, args.List)
+ }
+
+ {
+ var args cmd
+ err := parse("get --name test", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Nil(t, args.List)
+ assert.Equal(t, args.Get.Name, "test")
+ }
+}
+
+func TestNestedSubcommands(t *testing.T) {
+ type child struct{}
+ type parent struct {
+ Child *child `arg:"subcommand"`
+ }
+ type grandparent struct {
+ Parent *parent `arg:"subcommand"`
+ }
+ type root struct {
+ Grandparent *grandparent `arg:"subcommand"`
+ }
+
+ {
+ var args root
+ p, err := pparse("grandparent parent child", &args)
+ require.NoError(t, err)
+ require.NotNil(t, args.Grandparent)
+ require.NotNil(t, args.Grandparent.Parent)
+ require.NotNil(t, args.Grandparent.Parent.Child)
+ assert.Equal(t, args.Grandparent.Parent.Child, p.Subcommand())
+ assert.Equal(t, []string{"grandparent", "parent", "child"}, p.SubcommandNames())
+ }
+
+ {
+ var args root
+ p, err := pparse("grandparent parent", &args)
+ require.NoError(t, err)
+ require.NotNil(t, args.Grandparent)
+ require.NotNil(t, args.Grandparent.Parent)
+ require.Nil(t, args.Grandparent.Parent.Child)
+ assert.Equal(t, args.Grandparent.Parent, p.Subcommand())
+ assert.Equal(t, []string{"grandparent", "parent"}, p.SubcommandNames())
+ }
+
+ {
+ var args root
+ p, err := pparse("grandparent", &args)
+ require.NoError(t, err)
+ require.NotNil(t, args.Grandparent)
+ require.Nil(t, args.Grandparent.Parent)
+ assert.Equal(t, args.Grandparent, p.Subcommand())
+ assert.Equal(t, []string{"grandparent"}, p.SubcommandNames())
+ }
+
+ {
+ var args root
+ p, err := pparse("", &args)
+ require.NoError(t, err)
+ require.Nil(t, args.Grandparent)
+ assert.Nil(t, p.Subcommand())
+ assert.Empty(t, p.SubcommandNames())
+ }
+}
+
+func TestSubcommandsWithPositionals(t *testing.T) {
+ type listCmd struct {
+ Pattern string `arg:"positional"`
+ }
+ type cmd struct {
+ Format string
+ List *listCmd `arg:"subcommand"`
+ }
+
+ {
+ var args cmd
+ err := parse("list", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, "", args.List.Pattern)
+ }
+
+ {
+ var args cmd
+ err := parse("list --format json", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, "", args.List.Pattern)
+ assert.Equal(t, "json", args.Format)
+ }
+
+ {
+ var args cmd
+ err := parse("list somepattern", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, "somepattern", args.List.Pattern)
+ }
+
+ {
+ var args cmd
+ err := parse("list somepattern --format json", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, "somepattern", args.List.Pattern)
+ assert.Equal(t, "json", args.Format)
+ }
+
+ {
+ var args cmd
+ err := parse("list --format json somepattern", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, "somepattern", args.List.Pattern)
+ assert.Equal(t, "json", args.Format)
+ }
+
+ {
+ var args cmd
+ err := parse("--format json list somepattern", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.List)
+ assert.Equal(t, "somepattern", args.List.Pattern)
+ assert.Equal(t, "json", args.Format)
+ }
+
+ {
+ var args cmd
+ err := parse("--format json", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.List)
+ assert.Equal(t, "json", args.Format)
+ }
+}
+func TestSubcommandsWithMultiplePositionals(t *testing.T) {
+ type getCmd struct {
+ Items []string `arg:"positional"`
+ }
+ type cmd struct {
+ Limit int
+ Get *getCmd `arg:"subcommand"`
+ }
+
+ {
+ var args cmd
+ err := parse("get", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Empty(t, args.Get.Items)
+ }
+
+ {
+ var args cmd
+ err := parse("get --limit 5", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Empty(t, args.Get.Items)
+ assert.Equal(t, 5, args.Limit)
+ }
+
+ {
+ var args cmd
+ err := parse("get item1", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Equal(t, []string{"item1"}, args.Get.Items)
+ }
+
+ {
+ var args cmd
+ err := parse("get item1 item2 item3", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Equal(t, []string{"item1", "item2", "item3"}, args.Get.Items)
+ }
+
+ {
+ var args cmd
+ err := parse("get item1 --limit 5 item2", &args)
+ require.NoError(t, err)
+ assert.NotNil(t, args.Get)
+ assert.Equal(t, []string{"item1", "item2"}, args.Get.Items)
+ assert.Equal(t, 5, args.Limit)
+ }
+}
diff --git a/usage.go b/usage.go
index cfac563..33b6184 100644
--- a/usage.go
+++ b/usage.go
@@ -12,17 +12,30 @@ import (
// the width of the left column
const colWidth = 25
+// to allow monkey patching in tests
+var stderr = os.Stderr
+
// Fail prints usage information to stderr and exits with non-zero status
func (p *Parser) Fail(msg string) {
- p.WriteUsage(os.Stderr)
- fmt.Fprintln(os.Stderr, "error:", msg)
- os.Exit(-1)
+ p.failWithCommand(msg, p.cmd)
+}
+
+// failWithCommand prints usage information for the given subcommand to stderr and exits with non-zero status
+func (p *Parser) failWithCommand(msg string, cmd *command) {
+ p.writeUsageForCommand(stderr, cmd)
+ fmt.Fprintln(stderr, "error:", msg)
+ osExit(-1)
}
// WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) {
+ p.writeUsageForCommand(w, p.cmd)
+}
+
+// writeUsageForCommand writes usage information for the given subcommand
+func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
var positionals, options []*spec
- for _, spec := range p.specs {
+ for _, spec := range cmd.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {
@@ -34,7 +47,19 @@ func (p *Parser) WriteUsage(w io.Writer) {
fmt.Fprintln(w, p.version)
}
- fmt.Fprintf(w, "Usage: %s", p.config.Program)
+ // make a list of ancestor commands so that we print with full context
+ var ancestors []string
+ ancestor := cmd
+ for ancestor != nil {
+ ancestors = append(ancestors, ancestor.name)
+ ancestor = ancestor.parent
+ }
+
+ // print the beginning of the usage string
+ fmt.Fprint(w, "Usage:")
+ for i := len(ancestors) - 1; i >= 0; i-- {
+ fmt.Fprint(w, " "+ancestors[i])
+ }
// write the option component of the usage message
for _, spec := range options {
@@ -69,10 +94,32 @@ func (p *Parser) WriteUsage(w io.Writer) {
fmt.Fprint(w, "\n")
}
+func printTwoCols(w io.Writer, left, help string, defaultVal *string) {
+ lhs := " " + left
+ fmt.Fprint(w, lhs)
+ if help != "" {
+ if len(lhs)+2 < colWidth {
+ fmt.Fprint(w, strings.Repeat(" ", colWidth-len(lhs)))
+ } else {
+ fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth))
+ }
+ fmt.Fprint(w, help)
+ }
+ if defaultVal != nil {
+ fmt.Fprintf(w, " [default: %s]", *defaultVal)
+ }
+ fmt.Fprint(w, "\n")
+}
+
// WriteHelp writes the usage string followed by the full help string for each option
func (p *Parser) WriteHelp(w io.Writer) {
+ p.writeHelpForCommand(w, p.cmd)
+}
+
+// writeHelp writes the usage string for the given subcommand
+func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
var positionals, options []*spec
- for _, spec := range p.specs {
+ for _, spec := range cmd.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {
@@ -83,70 +130,74 @@ func (p *Parser) WriteHelp(w io.Writer) {
if p.description != "" {
fmt.Fprintln(w, p.description)
}
- p.WriteUsage(w)
+ p.writeUsageForCommand(w, cmd)
// write the list of positionals
if len(positionals) > 0 {
fmt.Fprint(w, "\nPositional arguments:\n")
for _, spec := range positionals {
- left := " " + strings.ToUpper(spec.long)
- fmt.Fprint(w, left)
- if spec.help != "" {
- if len(left)+2 < colWidth {
- fmt.Fprint(w, strings.Repeat(" ", colWidth-len(left)))
- } else {
- fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth))
- }
- fmt.Fprint(w, spec.help)
- }
- fmt.Fprint(w, "\n")
+ printTwoCols(w, strings.ToUpper(spec.long), spec.help, nil)
}
}
// write the list of options
fmt.Fprint(w, "\nOptions:\n")
for _, spec := range options {
- printOption(w, spec)
+ p.printOption(w, spec)
}
// write the list of built in options
- printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"})
+ p.printOption(w, &spec{
+ boolean: true,
+ long: "help",
+ short: "h",
+ help: "display this help and exit",
+ })
if p.version != "" {
- printOption(w, &spec{boolean: true, long: "version", help: "display version and exit"})
+ p.printOption(w, &spec{
+ boolean: true,
+ long: "version",
+ help: "display version and exit",
+ })
+ }
+
+ // write the list of subcommands
+ if len(cmd.subcommands) > 0 {
+ fmt.Fprint(w, "\nCommands:\n")
+ for _, subcmd := range cmd.subcommands {
+ printTwoCols(w, subcmd.name, subcmd.help, nil)
+ }
}
}
-func printOption(w io.Writer, spec *spec) {
- left := " " + synopsis(spec, "--"+spec.long)
+func (p *Parser) printOption(w io.Writer, spec *spec) {
+ left := synopsis(spec, "--"+spec.long)
if spec.short != "" {
left += ", " + synopsis(spec, "-"+spec.short)
}
- fmt.Fprint(w, left)
- if spec.help != "" {
- if len(left)+2 < colWidth {
- fmt.Fprint(w, strings.Repeat(" ", colWidth-len(left)))
- } else {
- fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth))
- }
- fmt.Fprint(w, spec.help)
- }
+
// If spec.dest is not the zero value then a default value has been added.
- v := spec.dest
+ var v reflect.Value
+ if len(spec.dest.fields) > 0 {
+ v = p.val(spec.dest)
+ }
+
+ var defaultVal *string
if v.IsValid() {
z := reflect.Zero(v.Type())
if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() {
if scalar, ok := v.Interface().(encoding.TextMarshaler); ok {
if value, err := scalar.MarshalText(); err != nil {
- fmt.Fprintf(w, " [default: error: %v]", err)
+ defaultVal = ptrTo(fmt.Sprintf("error: %v", err))
} else {
- fmt.Fprintf(w, " [default: %v]", string(value))
+ defaultVal = ptrTo(fmt.Sprintf("%v", string(value)))
}
} else {
- fmt.Fprintf(w, " [default: %v]", v)
+ defaultVal = ptrTo(fmt.Sprintf("%v", v))
}
}
}
- fmt.Fprint(w, "\n")
+ printTwoCols(w, left, spec.help, defaultVal)
}
func synopsis(spec *spec, form string) string {
@@ -155,3 +206,7 @@ func synopsis(spec *spec, form string) string {
}
return form + " " + strings.ToUpper(spec.long)
}
+
+func ptrTo(s string) *string {
+ return &s
+}