diff options
| -rw-r--r-- | README.md | 92 | ||||
| -rw-r--r-- | parse.go | 122 | ||||
| -rw-r--r-- | parse_test.go | 261 | ||||
| -rw-r--r-- | scalar.go | 145 | ||||
| -rw-r--r-- | usage.go | 4 | ||||
| -rw-r--r-- | usage_test.go | 7 |
6 files changed, 559 insertions, 72 deletions
@@ -4,6 +4,10 @@ ## Structured argument parsing for Go +```shell +go get github.com/alexflint/go-arg +``` + Declare the command line arguments your program accepts by defining a struct. ```go @@ -24,16 +28,16 @@ hello true ```go var args struct { - Foo string `arg:"required"` - Bar bool + ID int `arg:"required"` + Timeout time.Duration } arg.MustParse(&args) ``` ```shell $ ./example -usage: example --foo FOO [--bar] -error: --foo is required +usage: example --id ID [--timeout TIMEOUT] +error: --id is required ``` ### Positional arguments @@ -54,6 +58,41 @@ Input: src.txt Output: [x.out y.out z.out] ``` +### Environment variables + +```go +var args struct { + Workers int `arg:"env"` +} +arg.MustParse(&args) +fmt.Println("Workers:", args.Workers) +``` + +``` +$ WORKERS=4 ./example +Workers: 4 +``` + +``` +$ WORKERS=4 ./example --workers=6 +Workers: 6 +``` + +You can also override the name of the environment variable: + +```go +var args struct { + Workers int `arg:"env:NUM_WORKERS"` +} +arg.MustParse(&args) +fmt.Println("Workers:", args.Workers) +``` + +``` +$ NUM_WORKERS=4 ./example +Workers: 4 +``` + ### Usage strings ```go var args struct { @@ -126,10 +165,51 @@ usage: samples [--foo FOO] [--bar BAR] error: you must provide one of --foo and --bar ``` -### Installation +### Custom parsing + +You can implement your own argument parser by implementing `encoding.TextUnmarshaler`: + +```go +package main + +import ( + "fmt" + "strings" + + "github.com/alexflint/go-arg" +) + +// Accepts command line arguments of the form "head.tail" +type NameDotName struct { + Head, Tail string +} + +func (n *NameDotName) UnmarshalText(b []byte) error { + s := string(b) + pos := strings.Index(s, ".") + if pos == -1 { + return fmt.Errorf("missing period in %s", s) + } + n.Head = s[:pos] + n.Tail = s[pos+1:] + return nil +} +func main() { + var args struct { + Name *NameDotName + } + arg.MustParse(&args) + fmt.Printf("%#v\n", args.Name) +} +``` ```shell -go get github.com/alexflint/go-arg +$ ./example --name=foo.bar +&main.NameDotName{Head:"foo", Tail:"bar"} + +$ ./example --name=oops +usage: example [--name NAME] +error: error processing --name: missing period in "oops" ``` ### Documentation @@ -6,7 +6,6 @@ import ( "os" "path/filepath" "reflect" - "strconv" "strings" ) @@ -19,8 +18,10 @@ type spec struct { required bool positional bool help string + env string wasPresent bool - isBool bool + boolean bool + fieldName string // for generating helpful errors } // ErrHelp indicates that -h or --help were provided @@ -87,31 +88,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { } spec := spec{ - long: strings.ToLower(field.Name), - dest: v.Field(i), + long: strings.ToLower(field.Name), + dest: v.Field(i), + fieldName: t.Name() + "." + field.Name, } - // Get the scalar type for this field - scalarType := field.Type - if scalarType.Kind() == reflect.Slice { - spec.multiple = true - scalarType = scalarType.Elem() - if scalarType.Kind() == reflect.Ptr { - scalarType = scalarType.Elem() - } - } - - // Check for unsupported types - switch scalarType.Kind() { - case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Struct, - reflect.Complex64, reflect.Complex128: - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) - } - - // Specify that it is a bool for usage - if scalarType.Kind() == reflect.Bool { - spec.isBool = true + // Check whether this field is supported. It's good to do this here rather than + // wait until setScalar because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it recieved happend + // to exercise those fields. + var parseable bool + parseable, spec.boolean, spec.multiple = canParse(field.Type) + if !parseable { + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String()) } // Look at the tag @@ -137,6 +126,13 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { spec.positional = true case key == "help": spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } default: return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) } @@ -195,6 +191,15 @@ func process(specs []*spec, args []string) error { if spec.short != "" { optionMap[spec.short] = spec } + if spec.env != "" { + if value, found := os.LookupEnv(spec.env); found { + err := setScalar(spec.dest, value) + if err != nil { + return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) + } + spec.wasPresent = true + } + } } // process each string from the command line @@ -248,7 +253,8 @@ func process(specs []*spec, args []string) error { } // if it's a flag and it has no value then set the value to true - if spec.dest.Kind() == reflect.Bool && value == "" { + // use boolean because this takes account of TextUnmarshaler + if spec.boolean && value == "" { value = "true" } @@ -329,41 +335,37 @@ func setSlice(dest reflect.Value, values []string) error { return nil } -// set a value from a string -func setScalar(v reflect.Value, s string) error { - if !v.CanSet() { - return fmt.Errorf("field is not exported") +// canParse returns true if the type can be parsed from a string +func canParse(t reflect.Type) (parseable, boolean, multiple bool) { + parseable, boolean = isScalar(t) + if parseable { + return } - switch v.Kind() { - case reflect.String: - v.Set(reflect.ValueOf(s)) - case reflect.Bool: - x, err := strconv.ParseBool(s) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(s, 10, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(s, 10, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(s, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - default: - return fmt.Errorf("not a scalar type: %s", v.Kind()) + // Look inside pointer types + if t.Kind() == reflect.Ptr { + t = t.Elem() } - return nil + // Look inside slice types + if t.Kind() == reflect.Slice { + multiple = true + t = t.Elem() + } + + parseable, boolean = isScalar(t) + if parseable { + return + } + + // Look inside pointer types (again, in case of []*Type) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + parseable, boolean = isScalar(t) + if parseable { + return + } + + return false, false, false } diff --git a/parse_test.go b/parse_test.go index 1189588..964c9a7 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1,9 +1,12 @@ package arg import ( + "net" + "net/mail" "os" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,10 +17,14 @@ func parse(cmdline string, dest interface{}) error { if err != nil { return err } - return p.Parse(strings.Split(cmdline, " ")) + var parts []string + if len(cmdline) > 0 { + parts = strings.Split(cmdline, " ") + } + return p.Parse(parts) } -func TestStringSingle(t *testing.T) { +func TestString(t *testing.T) { var args struct { Foo string } @@ -26,6 +33,69 @@ func TestStringSingle(t *testing.T) { assert.Equal(t, "bar", args.Foo) } +func TestInt(t *testing.T) { + var args struct { + Foo int + } + err := parse("--foo 7", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) +} + +func TestUint(t *testing.T) { + var args struct { + Foo uint + } + err := parse("--foo 7", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) +} + +func TestFloat(t *testing.T) { + var args struct { + Foo float32 + } + err := parse("--foo 3.4", &args) + require.NoError(t, err) + assert.EqualValues(t, 3.4, args.Foo) +} + +func TestDuration(t *testing.T) { + var args struct { + Foo time.Duration + } + err := parse("--foo 3ms", &args) + require.NoError(t, err) + assert.Equal(t, 3*time.Millisecond, args.Foo) +} + +func TestInvalidDuration(t *testing.T) { + var args struct { + Foo time.Duration + } + err := parse("--foo xxx", &args) + require.Error(t, err) +} + +func TestIntPtr(t *testing.T) { + var args struct { + Foo *int + } + err := parse("--foo 123", &args) + require.NoError(t, err) + require.NotNil(t, args.Foo) + assert.Equal(t, 123, *args.Foo) +} + +func TestIntPtrNotPresent(t *testing.T) { + var args struct { + Foo *int + } + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.Foo) +} + func TestMixed(t *testing.T) { var args struct { Foo string `arg:"-f"` @@ -317,6 +387,14 @@ func TestUnsupportedSliceElement(t *testing.T) { var args struct { Foo []interface{} } + err := parse("--foo 3", &args) + assert.Error(t, err) +} + +func TestUnsupportedSliceElementMissingValue(t *testing.T) { + var args struct { + Foo []interface{} + } err := parse("--foo", &args) assert.Error(t, err) } @@ -357,3 +435,182 @@ func TestMustParse(t *testing.T) { assert.Equal(t, "bar", args.Foo) assert.NotNil(t, parser) } + +func TestEnvironmentVariable(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + os.Setenv("FOO", "bar") + os.Args = []string{"example"} + MustParse(&args) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableOverrideName(t *testing.T) { + var args struct { + Foo string `arg:"env:BAZ"` + } + os.Setenv("BAZ", "bar") + os.Args = []string{"example"} + MustParse(&args) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableOverrideArgument(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + os.Setenv("FOO", "bar") + os.Args = []string{"example", "--foo", "baz"} + MustParse(&args) + assert.Equal(t, "baz", args.Foo) +} + +func TestEnvironmentVariableError(t *testing.T) { + var args struct { + Foo int `arg:"env"` + } + os.Setenv("FOO", "bar") + os.Args = []string{"example"} + err := Parse(&args) + assert.Error(t, err) +} + +func TestEnvironmentVariableRequired(t *testing.T) { + var args struct { + Foo string `arg:"env,required"` + } + os.Setenv("FOO", "bar") + os.Args = []string{"example"} + MustParse(&args) + assert.Equal(t, "bar", args.Foo) +} + +type textUnmarshaler struct { + val int +} + +func (f *textUnmarshaler) UnmarshalText(b []byte) error { + f.val = len(b) + return nil +} + +func TestTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo *textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +type boolUnmarshaler bool + +func (p *boolUnmarshaler) UnmarshalText(b []byte) error { + *p = len(b)%2 == 0 + return nil +} + +func TestBoolUnmarhsaler(t *testing.T) { + // test that a bool type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a bool + var args struct { + Foo *boolUnmarshaler + } + err := parse("--foo ab", &args) + require.NoError(t, err) + assert.EqualValues(t, true, *args.Foo) +} + +type sliceUnmarshaler []int + +func (p *sliceUnmarshaler) UnmarshalText(b []byte) error { + *p = sliceUnmarshaler{len(b)} + return nil +} + +func TestSliceUnmarhsaler(t *testing.T) { + // test that a slice type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a slice + var args struct { + Foo *sliceUnmarshaler + Bar string `arg:"positional"` + } + err := parse("--foo abcde xyz", &args) + require.NoError(t, err) + require.Len(t, *args.Foo, 1) + assert.EqualValues(t, 5, (*args.Foo)[0]) + assert.Equal(t, "xyz", args.Bar) +} + +func TestIP(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestPtrToIP(t *testing.T) { + var args struct { + Host *net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestIPSlice(t *testing.T) { + var args struct { + Host []net.IP + } + err := parse("--host 192.168.0.1 127.0.0.1", &args) + require.NoError(t, err) + require.Len(t, args.Host, 2) + assert.Equal(t, "192.168.0.1", args.Host[0].String()) + assert.Equal(t, "127.0.0.1", args.Host[1].String()) +} + +func TestInvalidIPAddress(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMAC(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host 0123.4567.89ab", &args) + require.NoError(t, err) + assert.Equal(t, "01:23:45:67:89:ab", args.Host.String()) +} + +func TestInvalidMac(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient [email protected]", &args) + require.NoError(t, err) + assert.Equal(t, "<[email protected]>", args.Recipient.String()) +} + +func TestInvalidMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient xxx", &args) + assert.Error(t, err) +} diff --git a/scalar.go b/scalar.go new file mode 100644 index 0000000..e79b002 --- /dev/null +++ b/scalar.go @@ -0,0 +1,145 @@ +package arg + +import ( + "encoding" + "fmt" + "net" + "net/mail" + "reflect" + "strconv" + "time" +) + +// The reflected form of some special types +var ( + textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + durationType = reflect.TypeOf(time.Duration(0)) + mailAddressType = reflect.TypeOf(mail.Address{}) + ipType = reflect.TypeOf(net.IP{}) + macType = reflect.TypeOf(net.HardwareAddr{}) +) + +// isScalar returns true if the type can be parsed from a single string +func isScalar(t reflect.Type) (scalar, boolean bool) { + // If it implements encoding.TextUnmarshaler then use that + if t.Implements(textUnmarshalerType) { + // scalar=YES, boolean=NO + return true, false + } + + // If we have a pointer then dereference it + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // Check for other special types + switch t { + case durationType, mailAddressType, ipType, macType: + // scalar=YES, boolean=NO + return true, false + } + + // Fall back to checking the kind + switch t.Kind() { + case reflect.Bool: + // scalar=YES, boolean=YES + return true, true + case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + // scalar=YES, boolean=NO + return true, false + } + // scalar=NO, boolean=NO + return false, false +} + +// set a value from a string +func setScalar(v reflect.Value, s string) error { + if !v.CanSet() { + return fmt.Errorf("field is not exported") + } + + // If we have a nil pointer then allocate a new object + if v.Kind() == reflect.Ptr && v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + // Get the object as an interface + scalar := v.Interface() + + // If it implements encoding.TextUnmarshaler then use that + if scalar, ok := scalar.(encoding.TextUnmarshaler); ok { + return scalar.UnmarshalText([]byte(s)) + } + + // If we have a pointer then dereference it + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + // Switch on concrete type + switch scalar.(type) { + case time.Duration: + duration, err := time.ParseDuration(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(duration)) + return nil + case mail.Address: + addr, err := mail.ParseAddress(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(*addr)) + return nil + case net.IP: + ip := net.ParseIP(s) + if ip == nil { + return fmt.Errorf(`invalid IP address: "%s"`, s) + } + v.Set(reflect.ValueOf(ip)) + return nil + case net.HardwareAddr: + ip, err := net.ParseMAC(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(ip)) + return nil + } + + // Switch on kind so that we can handle derived types + switch v.Kind() { + case reflect.String: + v.SetString(s) + case reflect.Bool: + x, err := strconv.ParseBool(s) + if err != nil { + return err + } + v.SetBool(x) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(s, 10, v.Type().Bits()) + if err != nil { + return err + } + v.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(s, 10, v.Type().Bits()) + if err != nil { + return err + } + v.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil { + return err + } + v.SetFloat(x) + default: + return fmt.Errorf("cannot parse argument into %s", v.Type().String()) + } + return nil +} @@ -96,7 +96,7 @@ func (p *Parser) WriteHelp(w io.Writer) { } // write the list of built in options - printOption(w, &spec{isBool: true, long: "help", short: "h", help: "display this help and exit"}) + printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"}) } func printOption(w io.Writer, spec *spec) { @@ -125,7 +125,7 @@ func printOption(w io.Writer, spec *spec) { } func synopsis(spec *spec, form string) string { - if spec.isBool { + if spec.boolean { return form } return form + " " + strings.ToUpper(spec.long) diff --git a/usage_test.go b/usage_test.go index 130cd45..fd2ba3a 100644 --- a/usage_test.go +++ b/usage_test.go @@ -10,9 +10,9 @@ import ( ) func TestWriteUsage(t *testing.T) { - expectedUsage := "usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] INPUT [OUTPUT [OUTPUT ...]]\n" + expectedUsage := "usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--workers WORKERS] INPUT [OUTPUT [OUTPUT ...]]\n" - expectedHelp := `usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] INPUT [OUTPUT [OUTPUT ...]] + expectedHelp := `usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--workers WORKERS] INPUT [OUTPUT [OUTPUT ...]] positional arguments: input @@ -26,6 +26,8 @@ options: --optimize OPTIMIZE, -O OPTIMIZE optimization level --ids IDS Ids + --workers WORKERS, -w WORKERS + number of workers to start --help, -h display this help and exit ` var args struct { @@ -37,6 +39,7 @@ options: Dataset string `arg:"help:dataset to use"` Optimize int `arg:"-O,help:optimization level"` Ids []int64 `arg:"help:Ids"` + Workers int `arg:"-w,env:WORKERS,help:number of workers to start"` } args.Name = "Foo Bar" args.Value = 42 |
