diff options
Diffstat (limited to 'parse.go')
| -rw-r--r-- | parse.go | 57 |
1 files changed, 37 insertions, 20 deletions
@@ -1,6 +1,7 @@ package arg import ( + "encoding" "errors" "fmt" "os" @@ -20,11 +21,15 @@ type spec struct { env string wasPresent bool isBool bool + fieldName string // for generating helpful errors } // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") +// The TextUnmarshaler type in reflection form +var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + // MustParse processes command line arguments and exits upon failure func MustParse(dest ...interface{}) *Parser { p, err := NewParser(dest...) @@ -80,31 +85,42 @@ func NewParser(dests ...interface{}) (*Parser, error) { } spec := spec{ - long: strings.ToLower(field.Name), - dest: v.Field(i), + long: strings.ToLower(field.Name), + dest: v.Field(i), + fieldName: t.Name() + "." + field.Name, } - // Get the scalar type for this field - scalarType := field.Type - if scalarType.Kind() == reflect.Slice { - spec.multiple = true - scalarType = scalarType.Elem() + // Check whether this field is supported. It's good to do this here rather than + // wait until setScalar because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it recieved happend + // to exercise those fields. + if !field.Type.Implements(textUnsmarshalerType) { + scalarType := field.Type + // Look inside pointer types + if scalarType.Kind() == reflect.Ptr { + scalarType = scalarType.Elem() + } + // Check for bool + if scalarType.Kind() == reflect.Bool { + spec.isBool = true + } + // Look inside slice types + if scalarType.Kind() == reflect.Slice { + spec.multiple = true + scalarType = scalarType.Elem() + } + // Look inside pointer types (again, in case of []*Type) if scalarType.Kind() == reflect.Ptr { scalarType = scalarType.Elem() } - } - - // Check for unsupported types - switch scalarType.Kind() { - case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Struct, - reflect.Complex64, reflect.Complex128: - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) - } - // Specify that it is a bool for usage - if scalarType.Kind() == reflect.Bool { - spec.isBool = true + // Check for unsupported types + switch scalarType.Kind() { + case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Struct, + reflect.Complex64, reflect.Complex128: + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) + } } // Look at the tag @@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error { } // if it's a flag and it has no value then set the value to true - if spec.dest.Kind() == reflect.Bool && value == "" { + // use isBool because this takes account of TextUnmarshaler + if spec.isBool && value == "" { value = "true" } |
