diff options
Diffstat (limited to 'parse.go')
| -rw-r--r-- | parse.go | 518 |
1 files changed, 332 insertions, 186 deletions
@@ -1,7 +1,6 @@ package arg import ( - "encoding" "encoding/csv" "errors" "fmt" @@ -16,9 +15,36 @@ import ( // to enable monkey-patching during tests var osExit = os.Exit +// path represents a sequence of steps to find the output location for an +// argument or subcommand in the final destination struct +type path struct { + root int // index of the destination struct + fields []string // sequence of struct field names to traverse +} + +// String gets a string representation of the given path +func (p path) String() string { + if len(p.fields) == 0 { + return "args" + } + return "args." + strings.Join(p.fields, ".") +} + +// Child gets a new path representing a child of this path. +func (p path) Child(child string) path { + // copy the entire slice of fields to avoid possible slice overwrite + subfields := make([]string, len(p.fields)+1) + copy(subfields, append(p.fields, child)) + return path{ + root: p.root, + fields: subfields, + } +} + // spec represents a command line option type spec struct { - dest reflect.Value + dest path + typ reflect.Type long string short string multiple bool @@ -30,6 +56,16 @@ type spec struct { boolean bool } +// command represents a named subcommand, or the top-level command +type command struct { + name string + help string + dest path + specs []*spec + subcommands []*command + parent *command +} + // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") @@ -42,18 +78,19 @@ func MustParse(dest ...interface{}) *Parser { if err != nil { fmt.Println(err) osExit(-1) + return nil // just in case osExit was monkey-patched } err = p.Parse(flags()) switch { case err == ErrHelp: - p.WriteHelp(os.Stdout) + p.writeHelpForCommand(os.Stdout, p.lastCmd) osExit(0) case err == ErrVersion: fmt.Println(p.version) osExit(0) case err != nil: - p.Fail(err.Error()) + p.failWithCommand(err.Error(), p.lastCmd) } return p @@ -83,10 +120,14 @@ type Config struct { // Parser represents a set of command line options with destination values type Parser struct { - specs []*spec + cmd *command + roots []reflect.Value config Config version string description string + + // the following fields change curing processing of command line arguments + lastCmd *command } // Versioned is the interface that the destination struct should implement to @@ -106,66 +147,180 @@ type Described interface { } // walkFields calls a function for each field of a struct, recursively expanding struct fields. -func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) { - t := v.Type() +func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { for i := 0; i < t.NumField(); i++ { field := t.Field(i) - val := v.Field(i) - expand := visit(field, val, t) + expand := visit(field, t) if expand && field.Type.Kind() == reflect.Struct { - walkFields(val, visit) + walkFields(field.Type, visit) } } } // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { + // first pick a name for the command for use in the usage text + var name string + switch { + case config.Program != "": + name = config.Program + case len(os.Args) > 0: + name = filepath.Base(os.Args[0]) + default: + name = "program" + } + + // construct a parser p := Parser{ + cmd: &command{name: name}, config: config, } + + // make a list of roots for _, dest := range dests { + p.roots = append(p.roots, reflect.ValueOf(dest)) + } + + // process each of the destination values + for i, dest := range dests { + t := reflect.TypeOf(dest) + if t.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) + } + + cmd, err := cmdFromStruct(name, path{root: i}, t) + if err != nil { + return nil, err + } + p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) + if dest, ok := dest.(Versioned); ok { p.version = dest.Version() } if dest, ok := dest.(Described); ok { p.description = dest.Description() } - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type())) + } + + return &p, nil +} + +func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { + // commands can only be created from pointers to structs + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", + dest, t.Kind()) + } + + t = t.Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", + dest, t.Kind()) + } + + cmd := command{ + name: name, + dest: dest, + } + + var errs []string + walkFields(t, func(field reflect.StructField, t reflect.Type) bool { + // Check for the ignore switch in the tag + tag := field.Tag.Get("arg") + if tag == "-" { + return false } - v = v.Elem() - if v.Kind() != reflect.Struct { - panic(fmt.Sprintf("%T is not a struct pointer", dest)) + + // If this is an embedded struct then recurse into its fields + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true } - var errs []string - walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool { - // Check for the ignore switch in the tag - tag := field.Tag.Get("arg") - if tag == "-" { - return false - } + // duplicate the entire path to avoid slice overwrites + subdest := dest.Child(field.Name) + spec := spec{ + dest: subdest, + long: strings.ToLower(field.Name), + typ: field.Type, + } - // If this is an embedded struct then recurse into its fields - if field.Anonymous && field.Type.Kind() == reflect.Struct { - return true - } + help, exists := field.Tag.Lookup("help") + if exists { + spec.help = help + } - spec := spec{ - long: strings.ToLower(field.Name), - dest: val, - } + // Look at the tag + var isSubcommand bool // tracks whether this field is a subcommand + if tag != "" { + for _, key := range strings.Split(tag, ",") { + key = strings.TrimLeft(key, " ") + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } - help, exists := field.Tag.Lookup("help") - if exists { - spec.help = help + switch { + case strings.HasPrefix(key, "---"): + errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) + return false + } + spec.short = key[1:] + case key == "required": + spec.required = true + case key == "positional": + spec.positional = true + case key == "separate": + spec.separate = true + case key == "help": // deprecated + spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } + case key == "subcommand": + // decide on a name for the subcommand + cmdname := value + if cmdname == "" { + cmdname = strings.ToLower(field.Name) + } + + // parse the subcommand recursively + subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) + if err != nil { + errs = append(errs, err.Error()) + return false + } + + subcmd.parent = &cmd + subcmd.help = field.Tag.Get("help") + + cmd.subcommands = append(cmd.subcommands, subcmd) + isSubcommand = true + default: + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false + } } + } + + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + if !isSubcommand { + cmd.specs = append(cmd.specs, &spec) - // Check whether this field is supported. It's good to do this here rather than - // wait until ParseValue because it means that a program with invalid argument - // fields will always fail regardless of whether the arguments it received - // exercised those fields. var parseable bool parseable, spec.boolean, spec.multiple = canParse(field.Type) if !parseable { @@ -173,110 +328,50 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { t.Name(), field.Name, field.Type.String())) return false } + } - // Look at the tag - if tag != "" { - for _, key := range strings.Split(tag, ",") { - key = strings.TrimLeft(key, " ") - var value string - if pos := strings.Index(key, ":"); pos != -1 { - value = key[pos+1:] - key = key[:pos] - } - - switch { - case strings.HasPrefix(key, "---"): - errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) - case strings.HasPrefix(key, "--"): - spec.long = key[2:] - case strings.HasPrefix(key, "-"): - if len(key) != 2 { - errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", - t.Name(), field.Name)) - return false - } - spec.short = key[1:] - case key == "required": - spec.required = true - case key == "positional": - spec.positional = true - case key == "separate": - spec.separate = true - case key == "help": // deprecated - spec.help = value - case key == "env": - // Use override name if provided - if value != "" { - spec.env = value - } else { - spec.env = strings.ToUpper(field.Name) - } - default: - errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) - return false - } - } - } - p.specs = append(p.specs, &spec) + // if this was an embedded field then we already returned true up above + return false + }) - // if this was an embedded field then we already returned true up above - return false - }) + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) + } - if len(errs) > 0 { - return nil, errors.New(strings.Join(errs, "\n")) + // check that we don't have both positionals and subcommands + var hasPositional bool + for _, spec := range cmd.specs { + if spec.positional { + hasPositional = true } } - if p.config.Program == "" { - p.config.Program = "program" - if len(os.Args) > 0 { - p.config.Program = filepath.Base(os.Args[0]) - } + if hasPositional && len(cmd.subcommands) > 0 { + return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest) } - return &p, nil + + return &cmd, nil } // Parse processes the given command line option, storing the results in the field // of the structs from which NewParser was constructed func (p *Parser) Parse(args []string) error { - // If -h or --help were specified then print usage - for _, arg := range args { - if arg == "-h" || arg == "--help" { - return ErrHelp - } - if arg == "--version" { - return ErrVersion - } - if arg == "--" { - break + err := p.process(args) + if err != nil { + // If -h or --help were specified then make sure help text supercedes other errors + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return ErrHelp + } + if arg == "--" { + break + } } } - - // Process all command line arguments - return process(p.specs, args) + return err } -// process goes through arguments one-by-one, parses them, and assigns the result to -// the underlying struct field -func process(specs []*spec, args []string) error { - // track the options we have seen - wasPresent := make(map[*spec]bool) - - // construct a map from --option to spec - optionMap := make(map[string]*spec) - for _, spec := range specs { - if spec.positional { - continue - } - if spec.long != "" { - optionMap[spec.long] = spec - } - if spec.short != "" { - optionMap[spec.short] = spec - } - } - - // deal with environment vars +// process environment vars for the given arguments +func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { for _, spec := range specs { if spec.env == "" { continue @@ -298,7 +393,7 @@ func process(specs []*spec, args []string) error { err, ) } - if err = setSlice(spec.dest, values, !spec.separate); err != nil { + if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", spec.env, @@ -306,13 +401,36 @@ func process(specs []*spec, args []string) error { ) } } else { - if err := scalar.ParseValue(spec.dest, value); err != nil { + if err := scalar.ParseValue(p.val(spec.dest), value); err != nil { return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) } } wasPresent[spec] = true } + return nil +} + +// process goes through arguments one-by-one, parses them, and assigns the result to +// the underlying struct field +func (p *Parser) process(args []string) error { + // track the options we have seen + wasPresent := make(map[*spec]bool) + + // union of specs for the chain of subcommands encountered so far + curCmd := p.cmd + p.lastCmd = curCmd + + // make a copy of the specs because we will add to this list each time we expand a subcommand + specs := make([]*spec, len(curCmd.specs)) + copy(specs, curCmd.specs) + + // deal with environment vars + err := p.captureEnvVars(specs, wasPresent) + if err != nil { + return err + } + // process each string from the command line var allpositional bool var positionals []string @@ -326,10 +444,44 @@ func process(specs []*spec, args []string) error { } if !isFlag(arg) || allpositional { - positionals = append(positionals, arg) + // each subcommand can have either subcommands or positionals, but not both + if len(curCmd.subcommands) == 0 { + positionals = append(positionals, arg) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(curCmd.subcommands, arg) + if subcmd == nil { + return fmt.Errorf("invalid subcommand: %s", arg) + } + + // instantiate the field to point to a new struct + v := p.val(subcmd.dest) + v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers + + // add the new options to the set of allowed options + specs = append(specs, subcmd.specs...) + + // capture environment vars for these new options + err := p.captureEnvVars(subcmd.specs, wasPresent) + if err != nil { + return err + } + + curCmd = subcmd + p.lastCmd = curCmd continue } + // check for special --help and --version flags + switch arg { + case "-h", "--help": + return ErrHelp + case "--version": + return ErrVersion + } + // check for an equals sign, as in "--foo=bar" var value string opt := strings.TrimLeft(arg, "-") @@ -338,9 +490,10 @@ func process(specs []*spec, args []string) error { opt = opt[:pos] } - // lookup the spec for this option - spec, ok := optionMap[opt] - if !ok { + // lookup the spec for this option (note that the "specs" slice changes as + // we expand subcommands so it is better not to use a map) + spec := findOption(specs, opt) + if spec == nil { return fmt.Errorf("unknown argument %s", arg) } wasPresent[spec] = true @@ -359,7 +512,7 @@ func process(specs []*spec, args []string) error { } else { values = append(values, value) } - err := setSlice(spec.dest, values, !spec.separate) + err := setSlice(p.val(spec.dest), values, !spec.separate) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -377,14 +530,14 @@ func process(specs []*spec, args []string) error { if i+1 == len(args) { return fmt.Errorf("missing value for %s", arg) } - if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) { + if !nextIsNumeric(spec.typ, args[i+1]) && isFlag(args[i+1]) { return fmt.Errorf("missing value for %s", arg) } value = args[i+1] i++ } - err := scalar.ParseValue(spec.dest, value) + err := scalar.ParseValue(p.val(spec.dest), value) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -400,13 +553,13 @@ func process(specs []*spec, args []string) error { } wasPresent[spec] = true if spec.multiple { - err := setSlice(spec.dest, positionals, true) + err := setSlice(p.val(spec.dest), positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = nil } else { - err := scalar.ParseValue(spec.dest, positionals[0]) + err := scalar.ParseValue(p.val(spec.dest), positionals[0]) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } @@ -449,6 +602,30 @@ func isFlag(s string) bool { return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" } +// val returns a reflect.Value corresponding to the current value for the +// given path +func (p *Parser) val(dest path) reflect.Value { + v := p.roots[dest.root] + for _, field := range dest.fields { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + + v = v.FieldByName(field) + if !v.IsValid() { + // it is appropriate to panic here because this can only happen due to + // an internal bug in this library (since we construct the path ourselves + // by reflecting on the same struct) + panic(fmt.Errorf("error resolving path %v: %v has no field named %v", + dest.fields, v.Type(), field)) + } + } + return v +} + // parse a value as the appropriate type and store it in the struct func setSlice(dest reflect.Value, values []string, trunc bool) error { if !dest.CanSet() { @@ -480,56 +657,25 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error { return nil } -// canParse returns true if the type can be parsed from a string -func canParse(t reflect.Type) (parseable, boolean, multiple bool) { - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return - } - - // Look inside pointer types - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - // Look inside slice types - if t.Kind() == reflect.Slice { - multiple = true - t = t.Elem() - } - - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return - } - - // Look inside pointer types (again, in case of []*Type) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return +// findOption finds an option from its name, or returns null if no spec is found +func findOption(specs []*spec, name string) *spec { + for _, spec := range specs { + if spec.positional { + continue + } + if spec.long == name || spec.short == name { + return spec + } } - - return false, false, false + return nil } -var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() - -// isBoolean returns true if the type can be parsed from a single string -func isBoolean(t reflect.Type) bool { - switch { - case t.Implements(textUnmarshalerType): - return false - case t.Kind() == reflect.Bool: - return true - case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool: - return true - default: - return false +// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found +func findSubcommand(cmds []*command, name string) *command { + for _, cmd := range cmds { + if cmd.name == name { + return cmd + } } + return nil } |
