summaryrefslogtreecommitdiff
path: root/parse.go
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2015-10-31 18:26:58 -0700
committerAlex Flint <[email protected]>2015-10-31 18:26:58 -0700
commitb9ad104f3301e7d078cd9ba16410eae3f6e772aa (patch)
tree1babe040c11f8ff40f6ec65aadab0ab13d4562c9 /parse.go
parent8397a40f4cafd39c553df848854e022d33149fa5 (diff)
added usage generation
Diffstat (limited to 'parse.go')
-rw-r--r--parse.go193
1 files changed, 101 insertions, 92 deletions
diff --git a/parse.go b/parse.go
index 74b9175..d9382c9 100644
--- a/parse.go
+++ b/parse.go
@@ -1,17 +1,28 @@
-package arguments
+package arg
import (
"fmt"
- "log"
"os"
"reflect"
"strconv"
"strings"
)
+// spec represents a command line option
+type spec struct {
+ dest reflect.Value
+ long string
+ short string
+ multiple bool
+ required bool
+ positional bool
+ help string
+ wasPresent bool
+}
+
// MustParse processes command line arguments and exits upon failure.
-func MustParse(dest interface{}) {
- err := Parse(dest)
+func MustParse(dest ...interface{}) {
+ err := Parse(dest...)
if err != nil {
fmt.Println(err)
os.Exit(1)
@@ -19,122 +30,121 @@ func MustParse(dest interface{}) {
}
// Parse processes command line arguments and stores the result in args.
-func Parse(dest interface{}) error {
- return ParseFrom(dest, os.Args)
+func Parse(dest ...interface{}) error {
+ return ParseFrom(os.Args[1:], dest...)
}
// ParseFrom processes command line arguments and stores the result in args.
-func ParseFrom(dest interface{}, args []string) error {
- v := reflect.ValueOf(dest)
- if v.Kind() != reflect.Ptr {
- panic(fmt.Sprintf("%s is not a pointer type", v.Type().Name()))
+func ParseFrom(args []string, dest ...interface{}) error {
+ // Add the help option if one is not already defined
+ var internal struct {
+ Help bool `arg:"-h"`
}
- v = v.Elem()
// Parse the spec
- spec, err := extractSpec(v.Type())
+ dest = append(dest, &internal)
+ spec, err := extractSpec(dest...)
if err != nil {
return err
}
// Process args
- err = processArgs(v, spec, args)
+ err = processArgs(spec, args)
if err != nil {
return err
}
+ // If -h or --help were specified then print help
+ if internal.Help {
+ writeUsage(os.Stdout, spec)
+ os.Exit(0)
+ }
+
// Validate
return validate(spec)
}
-// spec represents information about an argument extracted from struct tags
-type spec struct {
- field reflect.StructField
- index int
- long string
- short string
- multiple bool
- required bool
- positional bool
- help string
- wasPresent bool
-}
-
// extractSpec gets specifications for each argument from the tags in a struct
-func extractSpec(t reflect.Type) ([]*spec, error) {
- if t.Kind() != reflect.Struct {
- panic(fmt.Sprintf("%s is not a struct pointer", t.Name()))
- }
-
+func extractSpec(dests ...interface{}) ([]*spec, error) {
var specs []*spec
- for i := 0; i < t.NumField(); i++ {
- // Check for the ignore switch in the tag
- field := t.Field(i)
- tag := field.Tag.Get("arg")
- if tag == "-" {
- continue
+ for _, dest := range dests {
+ v := reflect.ValueOf(dest)
+ if v.Kind() != reflect.Ptr {
+ panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type()))
}
-
- spec := spec{
- long: strings.ToLower(field.Name),
- field: field,
- index: i,
+ v = v.Elem()
+ if v.Kind() != reflect.Struct {
+ panic(fmt.Sprintf("%T is not a struct pointer", dest))
}
- // Get the scalar type for this field
- scalarType := field.Type
- log.Println(field.Name, field.Type, field.Type.Kind())
- if scalarType.Kind() == reflect.Slice {
- spec.multiple = true
- scalarType = scalarType.Elem()
- if scalarType.Kind() == reflect.Ptr {
- scalarType = scalarType.Elem()
+ t := v.Type()
+ for i := 0; i < t.NumField(); i++ {
+ // Check for the ignore switch in the tag
+ field := t.Field(i)
+ tag := field.Tag.Get("arg")
+ if tag == "-" {
+ continue
}
- }
- // Check for unsupported types
- switch scalarType.Kind() {
- case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
- reflect.Map, reflect.Ptr, reflect.Struct,
- reflect.Complex64, reflect.Complex128:
- return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
- }
+ spec := spec{
+ long: strings.ToLower(field.Name),
+ dest: v.Field(i),
+ }
- // Look at the tag
- if tag != "" {
- for _, key := range strings.Split(tag, ",") {
- var value string
- if pos := strings.Index(key, ":"); pos != -1 {
- value = key[pos+1:]
- key = key[:pos]
+ // Get the scalar type for this field
+ scalarType := field.Type
+ if scalarType.Kind() == reflect.Slice {
+ spec.multiple = true
+ scalarType = scalarType.Elem()
+ if scalarType.Kind() == reflect.Ptr {
+ scalarType = scalarType.Elem()
}
+ }
+
+ // Check for unsupported types
+ switch scalarType.Kind() {
+ case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
+ reflect.Map, reflect.Ptr, reflect.Struct,
+ reflect.Complex64, reflect.Complex128:
+ return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
+ }
+
+ // Look at the tag
+ if tag != "" {
+ for _, key := range strings.Split(tag, ",") {
+ var value string
+ if pos := strings.Index(key, ":"); pos != -1 {
+ value = key[pos+1:]
+ key = key[:pos]
+ }
- switch {
- case strings.HasPrefix(key, "--"):
- spec.long = key[2:]
- case strings.HasPrefix(key, "-"):
- if len(key) != 2 {
- return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name)
+ switch {
+ case strings.HasPrefix(key, "--"):
+ spec.long = key[2:]
+ case strings.HasPrefix(key, "-"):
+ if len(key) != 2 {
+ return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name)
+ }
+ spec.short = key[1:]
+ case key == "required":
+ spec.required = true
+ case key == "positional":
+ spec.positional = true
+ case key == "help":
+ spec.help = value
+ default:
+ return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag)
}
- spec.short = key[1:]
- case key == "required":
- spec.required = true
- case key == "positional":
- spec.positional = true
- case key == "help":
- spec.help = value
- default:
- return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag)
}
}
+ specs = append(specs, &spec)
}
- specs = append(specs, &spec)
}
return specs, nil
}
// processArgs processes arguments using a pre-constructed spec
-func processArgs(dest reflect.Value, specs []*spec, args []string) error {
+func processArgs(specs []*spec, args []string) error {
// construct a map from --option to spec
optionMap := make(map[string]*spec)
for _, spec := range specs {
@@ -192,7 +202,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
} else {
values = append(values, value)
}
- err := setSlice(dest.Field(spec.index), values)
+ err := setSlice(spec.dest, values)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@@ -200,7 +210,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
}
// if it's a flag and it has no value then set the value to true
- if spec.field.Type.Kind() == reflect.Bool && value == "" {
+ if spec.dest.Kind() == reflect.Bool && value == "" {
value = "true"
}
@@ -213,7 +223,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
i++
}
- err := setScalar(dest.Field(spec.index), value)
+ err := setScalar(spec.dest, value)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@@ -221,22 +231,21 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
// process positionals
for _, spec := range specs {
- label := strings.ToLower(spec.field.Name)
if spec.positional {
if spec.multiple {
- err := setSlice(dest.Field(spec.index), positionals)
+ err := setSlice(spec.dest, positionals)
if err != nil {
- return fmt.Errorf("error processing %s: %v", label, err)
+ return fmt.Errorf("error processing %s: %v", spec.long, err)
}
positionals = nil
} else if len(positionals) > 0 {
- err := setScalar(dest.Field(spec.index), positionals[0])
+ err := setScalar(spec.dest, positionals[0])
if err != nil {
- return fmt.Errorf("error processing %s: %v", label, err)
+ return fmt.Errorf("error processing %s: %v", spec.long, err)
}
positionals = positionals[1:]
} else if spec.required {
- return fmt.Errorf("%s is required", label)
+ return fmt.Errorf("%s is required", spec.long)
}
}
}
@@ -250,7 +259,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
func validate(spec []*spec) error {
for _, arg := range spec {
if !arg.positional && arg.required && !arg.wasPresent {
- return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name))
+ return fmt.Errorf("--%s is required", arg.long)
}
}
return nil