summaryrefslogtreecommitdiff
path: root/parse.go
diff options
context:
space:
mode:
Diffstat (limited to 'parse.go')
-rw-r--r--parse.go122
1 files changed, 62 insertions, 60 deletions
diff --git a/parse.go b/parse.go
index 32b9b9d..c959656 100644
--- a/parse.go
+++ b/parse.go
@@ -6,7 +6,6 @@ import (
"os"
"path/filepath"
"reflect"
- "strconv"
"strings"
)
@@ -19,8 +18,10 @@ type spec struct {
required bool
positional bool
help string
+ env string
wasPresent bool
- isBool bool
+ boolean bool
+ fieldName string // for generating helpful errors
}
// ErrHelp indicates that -h or --help were provided
@@ -87,31 +88,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
}
spec := spec{
- long: strings.ToLower(field.Name),
- dest: v.Field(i),
+ long: strings.ToLower(field.Name),
+ dest: v.Field(i),
+ fieldName: t.Name() + "." + field.Name,
}
- // Get the scalar type for this field
- scalarType := field.Type
- if scalarType.Kind() == reflect.Slice {
- spec.multiple = true
- scalarType = scalarType.Elem()
- if scalarType.Kind() == reflect.Ptr {
- scalarType = scalarType.Elem()
- }
- }
-
- // Check for unsupported types
- switch scalarType.Kind() {
- case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
- reflect.Map, reflect.Ptr, reflect.Struct,
- reflect.Complex64, reflect.Complex128:
- return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
- }
-
- // Specify that it is a bool for usage
- if scalarType.Kind() == reflect.Bool {
- spec.isBool = true
+ // Check whether this field is supported. It's good to do this here rather than
+ // wait until setScalar because it means that a program with invalid argument
+ // fields will always fail regardless of whether the arguments it recieved happend
+ // to exercise those fields.
+ var parseable bool
+ parseable, spec.boolean, spec.multiple = canParse(field.Type)
+ if !parseable {
+ return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())
}
// Look at the tag
@@ -137,6 +126,13 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
spec.positional = true
case key == "help":
spec.help = value
+ case key == "env":
+ // Use override name if provided
+ if value != "" {
+ spec.env = value
+ } else {
+ spec.env = strings.ToUpper(field.Name)
+ }
default:
return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag)
}
@@ -195,6 +191,15 @@ func process(specs []*spec, args []string) error {
if spec.short != "" {
optionMap[spec.short] = spec
}
+ if spec.env != "" {
+ if value, found := os.LookupEnv(spec.env); found {
+ err := setScalar(spec.dest, value)
+ if err != nil {
+ return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
+ }
+ spec.wasPresent = true
+ }
+ }
}
// process each string from the command line
@@ -248,7 +253,8 @@ func process(specs []*spec, args []string) error {
}
// if it's a flag and it has no value then set the value to true
- if spec.dest.Kind() == reflect.Bool && value == "" {
+ // use boolean because this takes account of TextUnmarshaler
+ if spec.boolean && value == "" {
value = "true"
}
@@ -329,41 +335,37 @@ func setSlice(dest reflect.Value, values []string) error {
return nil
}
-// set a value from a string
-func setScalar(v reflect.Value, s string) error {
- if !v.CanSet() {
- return fmt.Errorf("field is not exported")
+// canParse returns true if the type can be parsed from a string
+func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
+ parseable, boolean = isScalar(t)
+ if parseable {
+ return
}
- switch v.Kind() {
- case reflect.String:
- v.Set(reflect.ValueOf(s))
- case reflect.Bool:
- x, err := strconv.ParseBool(s)
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x))
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- x, err := strconv.ParseInt(s, 10, v.Type().Bits())
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x).Convert(v.Type()))
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- x, err := strconv.ParseUint(s, 10, v.Type().Bits())
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x).Convert(v.Type()))
- case reflect.Float32, reflect.Float64:
- x, err := strconv.ParseFloat(s, v.Type().Bits())
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x).Convert(v.Type()))
- default:
- return fmt.Errorf("not a scalar type: %s", v.Kind())
+ // Look inside pointer types
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
}
- return nil
+ // Look inside slice types
+ if t.Kind() == reflect.Slice {
+ multiple = true
+ t = t.Elem()
+ }
+
+ parseable, boolean = isScalar(t)
+ if parseable {
+ return
+ }
+
+ // Look inside pointer types (again, in case of []*Type)
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+
+ parseable, boolean = isScalar(t)
+ if parseable {
+ return
+ }
+
+ return false, false, false
}