summaryrefslogtreecommitdiff
path: root/parse.go
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2015-10-31 16:15:24 -0700
committerAlex Flint <[email protected]>2015-10-31 16:15:24 -0700
commit408290f7c2a968a0de255813e125a9ebb0a9dda6 (patch)
treeb1fe03f40b30315ffe000572a037ed6adf91e30e /parse.go
basic first version working
Diffstat (limited to 'parse.go')
-rw-r--r--parse.go268
1 files changed, 268 insertions, 0 deletions
diff --git a/parse.go b/parse.go
new file mode 100644
index 0000000..b58b51e
--- /dev/null
+++ b/parse.go
@@ -0,0 +1,268 @@
+package arguments
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+)
+
+// MustParse processes command line arguments and exits upon failure.
+func MustParse(dest interface{}) {
+ err := Parse(dest)
+ if err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+}
+
+// Parse processes command line arguments and stores the result in args.
+func Parse(dest interface{}) error {
+ return ParseFrom(dest, os.Args)
+}
+
+// ParseFrom processes command line arguments and stores the result in args.
+func ParseFrom(dest interface{}, args []string) error {
+ v := reflect.ValueOf(dest)
+ if v.Kind() != reflect.Ptr {
+ panic(fmt.Sprintf("%s is not a pointer type", v.Type().Name()))
+ }
+ v = v.Elem()
+
+ // Parse the spec
+ spec, err := extractSpec(v.Type())
+ if err != nil {
+ return err
+ }
+
+ // Process args
+ err = processArgs(v, spec, args)
+ if err != nil {
+ return err
+ }
+
+ // Validate
+ return validate(spec)
+}
+
+// spec represents information about an argument extracted from struct tags
+type spec struct {
+ field reflect.StructField
+ index int
+ long string
+ short string
+ multiple bool
+ required bool
+ positional bool
+ help string
+ wasPresent bool
+}
+
+// extractSpec gets specifications for each argument from the tags in a struct
+func extractSpec(t reflect.Type) ([]*spec, error) {
+ if t.Kind() != reflect.Struct {
+ panic(fmt.Sprintf("%s is not a struct pointer", t.Name()))
+ }
+
+ var specs []*spec
+ for i := 0; i < t.NumField(); i++ {
+ // Check for the ignore switch in the tag
+ field := t.Field(i)
+ tag := field.Tag.Get("arg")
+ if tag == "-" {
+ continue
+ }
+
+ spec := spec{
+ long: strings.ToLower(field.Name),
+ field: field,
+ index: i,
+ }
+
+ // Get the scalar type for this field
+ scalarType := field.Type
+ if scalarType.Kind() == reflect.Slice {
+ spec.multiple = true
+ scalarType = scalarType.Elem()
+ if scalarType.Kind() == reflect.Ptr {
+ scalarType = scalarType.Elem()
+ }
+ }
+
+ // Check for unsupported types
+ switch scalarType.Kind() {
+ case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
+ reflect.Map, reflect.Ptr, reflect.Struct,
+ reflect.Complex64, reflect.Complex128:
+ return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
+ }
+
+ // Look at the tag
+ if tag != "" {
+ for _, key := range strings.Split(tag, ",") {
+ var value string
+ if pos := strings.Index(key, ":"); pos != -1 {
+ value = key[pos+1:]
+ key = key[:pos]
+ }
+
+ switch {
+ case strings.HasPrefix(key, "--"):
+ spec.long = key[2:]
+ case strings.HasPrefix(key, "-"):
+ if len(key) != 2 {
+ return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name)
+ }
+ spec.short = key[1:]
+ case key == "required":
+ spec.required = true
+ case key == "positional":
+ spec.positional = true
+ case key == "help":
+ spec.help = value
+ default:
+ return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag)
+ }
+ }
+ }
+ specs = append(specs, &spec)
+ }
+ return specs, nil
+}
+
+// processArgs processes arguments using a pre-constructed spec
+func processArgs(dest reflect.Value, specs []*spec, args []string) error {
+ // construct a map from arg name to spec
+ specByName := make(map[string]*spec)
+ for _, spec := range specs {
+ if spec.long != "" {
+ specByName[spec.long] = spec
+ }
+ if spec.short != "" {
+ specByName[spec.short] = spec
+ }
+ }
+
+ // process each string from the command line
+ var allpositional bool
+ var positionals []string
+
+ // must use explicit for loop, not range, because we manipulate i inside the loop
+ for i := 0; i < len(args); i++ {
+ arg := args[i]
+ if arg == "--" {
+ allpositional = true
+ continue
+ }
+
+ if !strings.HasPrefix(arg, "-") || allpositional {
+ positionals = append(positionals, arg)
+ continue
+ }
+
+ // check for an equals sign, as in "--foo=bar"
+ var value string
+ opt := strings.TrimLeft(arg, "-")
+ if pos := strings.Index(opt, "="); pos != -1 {
+ value = opt[pos+1:]
+ opt = opt[:pos]
+ }
+
+ // lookup the spec for this option
+ spec, ok := specByName[opt]
+ if !ok {
+ return fmt.Errorf("unknown argument %s", arg)
+ }
+ spec.wasPresent = true
+
+ // deal with the case of multiple values
+ if spec.multiple {
+ var values []string
+ if value == "" {
+ for i++; i < len(args) && !strings.HasPrefix(args[i], "-"); i++ {
+ values = append(values, args[i])
+ }
+ } else {
+ values = append(values, value)
+ }
+ setSlice(dest, spec, values)
+ continue
+ }
+
+ // if it's a flag and it has no value then set the value to true
+ if spec.field.Type.Kind() == reflect.Bool && value == "" {
+ value = "true"
+ }
+
+ // if we have something like "--foo" then the value is the next argument
+ if value == "" {
+ if i+1 == len(args) || strings.HasPrefix(args[i+1], "-") {
+ return fmt.Errorf("missing value for %s", arg)
+ }
+ value = args[i+1]
+ i++
+ }
+
+ err := setScalar(dest.Field(spec.index), value)
+ if err != nil {
+ return fmt.Errorf("error processing %s: %v", arg, err)
+ }
+ }
+ return nil
+}
+
+// validate an argument spec after arguments have been parse
+func validate(spec []*spec) error {
+ for _, arg := range spec {
+ if arg.required && !arg.wasPresent {
+ return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name))
+ }
+ }
+ return nil
+}
+
+// parse a value as the apropriate type and store it in the struct
+func setSlice(dest reflect.Value, spec *spec, values []string) error {
+ // TODO
+ return nil
+}
+
+// set a value from a string
+func setScalar(v reflect.Value, s string) error {
+ if !v.CanSet() {
+ return fmt.Errorf("field is not writable")
+ }
+
+ switch v.Kind() {
+ case reflect.String:
+ v.Set(reflect.ValueOf(s))
+ case reflect.Bool:
+ x, err := strconv.ParseBool(s)
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(x))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ x, err := strconv.ParseInt(s, 10, v.Type().Bits())
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(x).Convert(v.Type()))
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ x, err := strconv.ParseUint(s, 10, v.Type().Bits())
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(x).Convert(v.Type()))
+ case reflect.Float32, reflect.Float64:
+ x, err := strconv.ParseFloat(s, v.Type().Bits())
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(x).Convert(v.Type()))
+ default:
+ return fmt.Errorf("not a scalar type: %s", v.Kind())
+ }
+ return nil
+}