summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md57
-rw-r--r--parse.go57
-rw-r--r--parse_test.go91
-rw-r--r--scalar.go31
4 files changed, 201 insertions, 35 deletions
diff --git a/README.md b/README.md
index 3d1d12f..28ff388 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,10 @@
## Structured argument parsing for Go
+```shell
+go get github.com/alexflint/go-arg
+```
+
Declare the command line arguments your program accepts by defining a struct.
```go
@@ -24,16 +28,16 @@ hello true
```go
var args struct {
- Foo string `arg:"required"`
- Bar bool
+ ID int `arg:"required"`
+ Timeout time.Duration
}
arg.MustParse(&args)
```
```shell
$ ./example
-usage: example --foo FOO [--bar]
-error: --foo is required
+usage: example --id ID [--timeout TIMEOUT]
+error: --id is required
```
### Positional arguments
@@ -161,10 +165,51 @@ usage: samples [--foo FOO] [--bar BAR]
error: you must provide one of --foo and --bar
```
-### Installation
+### Custom parsing
+
+You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:
+
+```go
+package main
+
+import (
+ "fmt"
+ "strings"
+ "github.com/alexflint/go-arg"
+)
+
+// Accepts command line arguments of the form "head.tail"
+type NameDotName struct {
+ Head, Tail string
+}
+
+func (n *NameDotName) UnmarshalText(b []byte) error {
+ s := string(b)
+ pos := strings.Index(s, ".")
+ if pos == -1 {
+ return fmt.Errorf("missing period in %s", s)
+ }
+ n.Head = s[:pos]
+ n.Tail = s[pos+1:]
+ return nil
+}
+
+func main() {
+ var args struct {
+ Name *NameDotName
+ }
+ arg.MustParse(&args)
+ fmt.Printf("%#v\n", args.Name)
+}
+```
```shell
-go get github.com/alexflint/go-arg
+$ ./example --name=foo.bar
+&main.NameDotName{Head:"foo", Tail:"bar"}
+
+$ ./example --name=oops
+usage: example [--name NAME]
+error: error processing --name: missing period in "oops"
```
### Documentation
diff --git a/parse.go b/parse.go
index 39eb52c..3895ce9 100644
--- a/parse.go
+++ b/parse.go
@@ -1,6 +1,7 @@
package arg
import (
+ "encoding"
"errors"
"fmt"
"os"
@@ -20,11 +21,15 @@ type spec struct {
env string
wasPresent bool
isBool bool
+ fieldName string // for generating helpful errors
}
// ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user")
+// The TextUnmarshaler type in reflection form
+var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
+
// MustParse processes command line arguments and exits upon failure
func MustParse(dest ...interface{}) *Parser {
p, err := NewParser(dest...)
@@ -80,31 +85,42 @@ func NewParser(dests ...interface{}) (*Parser, error) {
}
spec := spec{
- long: strings.ToLower(field.Name),
- dest: v.Field(i),
+ long: strings.ToLower(field.Name),
+ dest: v.Field(i),
+ fieldName: t.Name() + "." + field.Name,
}
- // Get the scalar type for this field
- scalarType := field.Type
- if scalarType.Kind() == reflect.Slice {
- spec.multiple = true
- scalarType = scalarType.Elem()
+ // Check whether this field is supported. It's good to do this here rather than
+ // wait until setScalar because it means that a program with invalid argument
+ // fields will always fail regardless of whether the arguments it recieved happend
+ // to exercise those fields.
+ if !field.Type.Implements(textUnsmarshalerType) {
+ scalarType := field.Type
+ // Look inside pointer types
+ if scalarType.Kind() == reflect.Ptr {
+ scalarType = scalarType.Elem()
+ }
+ // Check for bool
+ if scalarType.Kind() == reflect.Bool {
+ spec.isBool = true
+ }
+ // Look inside slice types
+ if scalarType.Kind() == reflect.Slice {
+ spec.multiple = true
+ scalarType = scalarType.Elem()
+ }
+ // Look inside pointer types (again, in case of []*Type)
if scalarType.Kind() == reflect.Ptr {
scalarType = scalarType.Elem()
}
- }
-
- // Check for unsupported types
- switch scalarType.Kind() {
- case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
- reflect.Map, reflect.Ptr, reflect.Struct,
- reflect.Complex64, reflect.Complex128:
- return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
- }
- // Specify that it is a bool for usage
- if scalarType.Kind() == reflect.Bool {
- spec.isBool = true
+ // Check for unsupported types
+ switch scalarType.Kind() {
+ case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
+ reflect.Map, reflect.Ptr, reflect.Struct,
+ reflect.Complex64, reflect.Complex128:
+ return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
+ }
}
// Look at the tag
@@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error {
}
// if it's a flag and it has no value then set the value to true
- if spec.dest.Kind() == reflect.Bool && value == "" {
+ // use isBool because this takes account of TextUnmarshaler
+ if spec.isBool && value == "" {
value = "true"
}
diff --git a/parse_test.go b/parse_test.go
index c30809d..a915910 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -15,7 +15,11 @@ func parse(cmdline string, dest interface{}) error {
if err != nil {
return err
}
- return p.Parse(strings.Split(cmdline, " "))
+ var parts []string
+ if len(cmdline) > 0 {
+ parts = strings.Split(cmdline, " ")
+ }
+ return p.Parse(parts)
}
func TestString(t *testing.T) {
@@ -71,6 +75,25 @@ func TestInvalidDuration(t *testing.T) {
require.Error(t, err)
}
+func TestIntPtr(t *testing.T) {
+ var args struct {
+ Foo *int
+ }
+ err := parse("--foo 123", &args)
+ require.NoError(t, err)
+ require.NotNil(t, args.Foo)
+ assert.Equal(t, 123, *args.Foo)
+}
+
+func TestIntPtrNotPresent(t *testing.T) {
+ var args struct {
+ Foo *int
+ }
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Foo)
+}
+
func TestMixed(t *testing.T) {
var args struct {
Foo string `arg:"-f"`
@@ -362,6 +385,14 @@ func TestUnsupportedSliceElement(t *testing.T) {
var args struct {
Foo []interface{}
}
+ err := parse("--foo 3", &args)
+ assert.Error(t, err)
+}
+
+func TestUnsupportedSliceElementMissingValue(t *testing.T) {
+ var args struct {
+ Foo []interface{}
+ }
err := parse("--foo", &args)
assert.Error(t, err)
}
@@ -452,3 +483,61 @@ func TestEnvironmentVariableRequired(t *testing.T) {
MustParse(&args)
assert.Equal(t, "bar", args.Foo)
}
+
+type textUnmarshaler struct {
+ val int
+}
+
+func (f *textUnmarshaler) UnmarshalText(b []byte) error {
+ f.val = len(b)
+ return nil
+}
+
+func TestTextUnmarshaler(t *testing.T) {
+ // fields that implement TextUnmarshaler should be parsed using that interface
+ var args struct {
+ Foo *textUnmarshaler
+ }
+ err := parse("--foo abc", &args)
+ require.NoError(t, err)
+ assert.Equal(t, 3, args.Foo.val)
+}
+
+type boolUnmarshaler bool
+
+func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
+ *p = len(b)%2 == 0
+ return nil
+}
+
+func TestBoolUnmarhsaler(t *testing.T) {
+ // test that a bool type that implements TextUnmarshaler is
+ // handled as a TextUnmarshaler not as a bool
+ var args struct {
+ Foo *boolUnmarshaler
+ }
+ err := parse("--foo ab", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, true, *args.Foo)
+}
+
+type sliceUnmarshaler []int
+
+func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
+ *p = sliceUnmarshaler{len(b)}
+ return nil
+}
+
+func TestSliceUnmarhsaler(t *testing.T) {
+ // test that a slice type that implements TextUnmarshaler is
+ // handled as a TextUnmarshaler not as a slice
+ var args struct {
+ Foo *sliceUnmarshaler
+ Bar string `arg:"positional"`
+ }
+ err := parse("--foo abcde xyz", &args)
+ require.NoError(t, err)
+ require.Len(t, *args.Foo, 1)
+ assert.EqualValues(t, 5, (*args.Foo)[0])
+ assert.Equal(t, "xyz", args.Bar)
+}
diff --git a/scalar.go b/scalar.go
index a3bafe4..67b4540 100644
--- a/scalar.go
+++ b/scalar.go
@@ -8,19 +8,33 @@ import (
"time"
)
-var (
- durationType = reflect.TypeOf(time.Duration(0))
- textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
-)
-
// set a value from a string
func setScalar(v reflect.Value, s string) error {
if !v.CanSet() {
return fmt.Errorf("field is not exported")
}
- // If we have a time.Duration then use time.ParseDuration
- if v.Type() == durationType {
+ // If we have a nil pointer then allocate a new object
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ v.Set(reflect.New(v.Type().Elem()))
+ }
+
+ // Get the object as an interface
+ scalar := v.Interface()
+
+ // If it implements encoding.TextUnmarshaler then use that
+ if scalar, ok := scalar.(encoding.TextUnmarshaler); ok {
+ return scalar.UnmarshalText([]byte(s))
+ }
+
+ // If we have a pointer then dereference it
+ if v.Kind() == reflect.Ptr {
+ v = v.Elem()
+ }
+
+ // Switch on concrete type
+ switch scalar.(type) {
+ case time.Duration:
x, err := time.ParseDuration(s)
if err != nil {
return err
@@ -29,6 +43,7 @@ func setScalar(v reflect.Value, s string) error {
return nil
}
+ // Switch on kind so that we can handle derived types
switch v.Kind() {
case reflect.String:
v.SetString(s)
@@ -57,7 +72,7 @@ func setScalar(v reflect.Value, s string) error {
}
v.SetFloat(x)
default:
- return fmt.Errorf("not a scalar type: %s", v.Kind())
+ return fmt.Errorf("cannot parse argument into %s", v.Type().String())
}
return nil
}