diff options
Diffstat (limited to 'parse.go')
| -rw-r--r-- | parse.go | 122 |
1 files changed, 62 insertions, 60 deletions
@@ -6,7 +6,6 @@ import ( "os" "path/filepath" "reflect" - "strconv" "strings" ) @@ -19,8 +18,10 @@ type spec struct { required bool positional bool help string + env string wasPresent bool - isBool bool + boolean bool + fieldName string // for generating helpful errors } // ErrHelp indicates that -h or --help were provided @@ -87,31 +88,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { } spec := spec{ - long: strings.ToLower(field.Name), - dest: v.Field(i), + long: strings.ToLower(field.Name), + dest: v.Field(i), + fieldName: t.Name() + "." + field.Name, } - // Get the scalar type for this field - scalarType := field.Type - if scalarType.Kind() == reflect.Slice { - spec.multiple = true - scalarType = scalarType.Elem() - if scalarType.Kind() == reflect.Ptr { - scalarType = scalarType.Elem() - } - } - - // Check for unsupported types - switch scalarType.Kind() { - case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Struct, - reflect.Complex64, reflect.Complex128: - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) - } - - // Specify that it is a bool for usage - if scalarType.Kind() == reflect.Bool { - spec.isBool = true + // Check whether this field is supported. It's good to do this here rather than + // wait until setScalar because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it recieved happend + // to exercise those fields. + var parseable bool + parseable, spec.boolean, spec.multiple = canParse(field.Type) + if !parseable { + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String()) } // Look at the tag @@ -137,6 +126,13 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { spec.positional = true case key == "help": spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } default: return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) } @@ -195,6 +191,15 @@ func process(specs []*spec, args []string) error { if spec.short != "" { optionMap[spec.short] = spec } + if spec.env != "" { + if value, found := os.LookupEnv(spec.env); found { + err := setScalar(spec.dest, value) + if err != nil { + return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) + } + spec.wasPresent = true + } + } } // process each string from the command line @@ -248,7 +253,8 @@ func process(specs []*spec, args []string) error { } // if it's a flag and it has no value then set the value to true - if spec.dest.Kind() == reflect.Bool && value == "" { + // use boolean because this takes account of TextUnmarshaler + if spec.boolean && value == "" { value = "true" } @@ -329,41 +335,37 @@ func setSlice(dest reflect.Value, values []string) error { return nil } -// set a value from a string -func setScalar(v reflect.Value, s string) error { - if !v.CanSet() { - return fmt.Errorf("field is not exported") +// canParse returns true if the type can be parsed from a string +func canParse(t reflect.Type) (parseable, boolean, multiple bool) { + parseable, boolean = isScalar(t) + if parseable { + return } - switch v.Kind() { - case reflect.String: - v.Set(reflect.ValueOf(s)) - case reflect.Bool: - x, err := strconv.ParseBool(s) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(s, 10, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(s, 10, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(s, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - default: - return fmt.Errorf("not a scalar type: %s", v.Kind()) + // Look inside pointer types + if t.Kind() == reflect.Ptr { + t = t.Elem() } - return nil + // Look inside slice types + if t.Kind() == reflect.Slice { + multiple = true + t = t.Elem() + } + + parseable, boolean = isScalar(t) + if parseable { + return + } + + // Look inside pointer types (again, in case of []*Type) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + parseable, boolean = isScalar(t) + if parseable { + return + } + + return false, false, false } |
