summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md92
-rw-r--r--parse.go122
-rw-r--r--parse_test.go261
-rw-r--r--scalar.go145
-rw-r--r--usage.go4
-rw-r--r--usage_test.go7
6 files changed, 559 insertions, 72 deletions
diff --git a/README.md b/README.md
index f4c8d11..28ff388 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,10 @@
## Structured argument parsing for Go
+```shell
+go get github.com/alexflint/go-arg
+```
+
Declare the command line arguments your program accepts by defining a struct.
```go
@@ -24,16 +28,16 @@ hello true
```go
var args struct {
- Foo string `arg:"required"`
- Bar bool
+ ID int `arg:"required"`
+ Timeout time.Duration
}
arg.MustParse(&args)
```
```shell
$ ./example
-usage: example --foo FOO [--bar]
-error: --foo is required
+usage: example --id ID [--timeout TIMEOUT]
+error: --id is required
```
### Positional arguments
@@ -54,6 +58,41 @@ Input: src.txt
Output: [x.out y.out z.out]
```
+### Environment variables
+
+```go
+var args struct {
+ Workers int `arg:"env"`
+}
+arg.MustParse(&args)
+fmt.Println("Workers:", args.Workers)
+```
+
+```
+$ WORKERS=4 ./example
+Workers: 4
+```
+
+```
+$ WORKERS=4 ./example --workers=6
+Workers: 6
+```
+
+You can also override the name of the environment variable:
+
+```go
+var args struct {
+ Workers int `arg:"env:NUM_WORKERS"`
+}
+arg.MustParse(&args)
+fmt.Println("Workers:", args.Workers)
+```
+
+```
+$ NUM_WORKERS=4 ./example
+Workers: 4
+```
+
### Usage strings
```go
var args struct {
@@ -126,10 +165,51 @@ usage: samples [--foo FOO] [--bar BAR]
error: you must provide one of --foo and --bar
```
-### Installation
+### Custom parsing
+
+You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:
+
+```go
+package main
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/alexflint/go-arg"
+)
+
+// Accepts command line arguments of the form "head.tail"
+type NameDotName struct {
+ Head, Tail string
+}
+
+func (n *NameDotName) UnmarshalText(b []byte) error {
+ s := string(b)
+ pos := strings.Index(s, ".")
+ if pos == -1 {
+ return fmt.Errorf("missing period in %s", s)
+ }
+ n.Head = s[:pos]
+ n.Tail = s[pos+1:]
+ return nil
+}
+func main() {
+ var args struct {
+ Name *NameDotName
+ }
+ arg.MustParse(&args)
+ fmt.Printf("%#v\n", args.Name)
+}
+```
```shell
-go get github.com/alexflint/go-arg
+$ ./example --name=foo.bar
+&main.NameDotName{Head:"foo", Tail:"bar"}
+
+$ ./example --name=oops
+usage: example [--name NAME]
+error: error processing --name: missing period in "oops"
```
### Documentation
diff --git a/parse.go b/parse.go
index 32b9b9d..c959656 100644
--- a/parse.go
+++ b/parse.go
@@ -6,7 +6,6 @@ import (
"os"
"path/filepath"
"reflect"
- "strconv"
"strings"
)
@@ -19,8 +18,10 @@ type spec struct {
required bool
positional bool
help string
+ env string
wasPresent bool
- isBool bool
+ boolean bool
+ fieldName string // for generating helpful errors
}
// ErrHelp indicates that -h or --help were provided
@@ -87,31 +88,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
}
spec := spec{
- long: strings.ToLower(field.Name),
- dest: v.Field(i),
+ long: strings.ToLower(field.Name),
+ dest: v.Field(i),
+ fieldName: t.Name() + "." + field.Name,
}
- // Get the scalar type for this field
- scalarType := field.Type
- if scalarType.Kind() == reflect.Slice {
- spec.multiple = true
- scalarType = scalarType.Elem()
- if scalarType.Kind() == reflect.Ptr {
- scalarType = scalarType.Elem()
- }
- }
-
- // Check for unsupported types
- switch scalarType.Kind() {
- case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
- reflect.Map, reflect.Ptr, reflect.Struct,
- reflect.Complex64, reflect.Complex128:
- return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
- }
-
- // Specify that it is a bool for usage
- if scalarType.Kind() == reflect.Bool {
- spec.isBool = true
+ // Check whether this field is supported. It's good to do this here rather than
+ // wait until setScalar because it means that a program with invalid argument
+ // fields will always fail regardless of whether the arguments it recieved happend
+ // to exercise those fields.
+ var parseable bool
+ parseable, spec.boolean, spec.multiple = canParse(field.Type)
+ if !parseable {
+ return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())
}
// Look at the tag
@@ -137,6 +126,13 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
spec.positional = true
case key == "help":
spec.help = value
+ case key == "env":
+ // Use override name if provided
+ if value != "" {
+ spec.env = value
+ } else {
+ spec.env = strings.ToUpper(field.Name)
+ }
default:
return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag)
}
@@ -195,6 +191,15 @@ func process(specs []*spec, args []string) error {
if spec.short != "" {
optionMap[spec.short] = spec
}
+ if spec.env != "" {
+ if value, found := os.LookupEnv(spec.env); found {
+ err := setScalar(spec.dest, value)
+ if err != nil {
+ return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
+ }
+ spec.wasPresent = true
+ }
+ }
}
// process each string from the command line
@@ -248,7 +253,8 @@ func process(specs []*spec, args []string) error {
}
// if it's a flag and it has no value then set the value to true
- if spec.dest.Kind() == reflect.Bool && value == "" {
+ // use boolean because this takes account of TextUnmarshaler
+ if spec.boolean && value == "" {
value = "true"
}
@@ -329,41 +335,37 @@ func setSlice(dest reflect.Value, values []string) error {
return nil
}
-// set a value from a string
-func setScalar(v reflect.Value, s string) error {
- if !v.CanSet() {
- return fmt.Errorf("field is not exported")
+// canParse returns true if the type can be parsed from a string
+func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
+ parseable, boolean = isScalar(t)
+ if parseable {
+ return
}
- switch v.Kind() {
- case reflect.String:
- v.Set(reflect.ValueOf(s))
- case reflect.Bool:
- x, err := strconv.ParseBool(s)
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x))
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- x, err := strconv.ParseInt(s, 10, v.Type().Bits())
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x).Convert(v.Type()))
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- x, err := strconv.ParseUint(s, 10, v.Type().Bits())
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x).Convert(v.Type()))
- case reflect.Float32, reflect.Float64:
- x, err := strconv.ParseFloat(s, v.Type().Bits())
- if err != nil {
- return err
- }
- v.Set(reflect.ValueOf(x).Convert(v.Type()))
- default:
- return fmt.Errorf("not a scalar type: %s", v.Kind())
+ // Look inside pointer types
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
}
- return nil
+ // Look inside slice types
+ if t.Kind() == reflect.Slice {
+ multiple = true
+ t = t.Elem()
+ }
+
+ parseable, boolean = isScalar(t)
+ if parseable {
+ return
+ }
+
+ // Look inside pointer types (again, in case of []*Type)
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+
+ parseable, boolean = isScalar(t)
+ if parseable {
+ return
+ }
+
+ return false, false, false
}
diff --git a/parse_test.go b/parse_test.go
index 1189588..964c9a7 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -1,9 +1,12 @@
package arg
import (
+ "net"
+ "net/mail"
"os"
"strings"
"testing"
+ "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -14,10 +17,14 @@ func parse(cmdline string, dest interface{}) error {
if err != nil {
return err
}
- return p.Parse(strings.Split(cmdline, " "))
+ var parts []string
+ if len(cmdline) > 0 {
+ parts = strings.Split(cmdline, " ")
+ }
+ return p.Parse(parts)
}
-func TestStringSingle(t *testing.T) {
+func TestString(t *testing.T) {
var args struct {
Foo string
}
@@ -26,6 +33,69 @@ func TestStringSingle(t *testing.T) {
assert.Equal(t, "bar", args.Foo)
}
+func TestInt(t *testing.T) {
+ var args struct {
+ Foo int
+ }
+ err := parse("--foo 7", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, 7, args.Foo)
+}
+
+func TestUint(t *testing.T) {
+ var args struct {
+ Foo uint
+ }
+ err := parse("--foo 7", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, 7, args.Foo)
+}
+
+func TestFloat(t *testing.T) {
+ var args struct {
+ Foo float32
+ }
+ err := parse("--foo 3.4", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, 3.4, args.Foo)
+}
+
+func TestDuration(t *testing.T) {
+ var args struct {
+ Foo time.Duration
+ }
+ err := parse("--foo 3ms", &args)
+ require.NoError(t, err)
+ assert.Equal(t, 3*time.Millisecond, args.Foo)
+}
+
+func TestInvalidDuration(t *testing.T) {
+ var args struct {
+ Foo time.Duration
+ }
+ err := parse("--foo xxx", &args)
+ require.Error(t, err)
+}
+
+func TestIntPtr(t *testing.T) {
+ var args struct {
+ Foo *int
+ }
+ err := parse("--foo 123", &args)
+ require.NoError(t, err)
+ require.NotNil(t, args.Foo)
+ assert.Equal(t, 123, *args.Foo)
+}
+
+func TestIntPtrNotPresent(t *testing.T) {
+ var args struct {
+ Foo *int
+ }
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Foo)
+}
+
func TestMixed(t *testing.T) {
var args struct {
Foo string `arg:"-f"`
@@ -317,6 +387,14 @@ func TestUnsupportedSliceElement(t *testing.T) {
var args struct {
Foo []interface{}
}
+ err := parse("--foo 3", &args)
+ assert.Error(t, err)
+}
+
+func TestUnsupportedSliceElementMissingValue(t *testing.T) {
+ var args struct {
+ Foo []interface{}
+ }
err := parse("--foo", &args)
assert.Error(t, err)
}
@@ -357,3 +435,182 @@ func TestMustParse(t *testing.T) {
assert.Equal(t, "bar", args.Foo)
assert.NotNil(t, parser)
}
+
+func TestEnvironmentVariable(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "bar", args.Foo)
+}
+
+func TestEnvironmentVariableOverrideName(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env:BAZ"`
+ }
+ os.Setenv("BAZ", "bar")
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "bar", args.Foo)
+}
+
+func TestEnvironmentVariableOverrideArgument(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example", "--foo", "baz"}
+ MustParse(&args)
+ assert.Equal(t, "baz", args.Foo)
+}
+
+func TestEnvironmentVariableError(t *testing.T) {
+ var args struct {
+ Foo int `arg:"env"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example"}
+ err := Parse(&args)
+ assert.Error(t, err)
+}
+
+func TestEnvironmentVariableRequired(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env,required"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "bar", args.Foo)
+}
+
+type textUnmarshaler struct {
+ val int
+}
+
+func (f *textUnmarshaler) UnmarshalText(b []byte) error {
+ f.val = len(b)
+ return nil
+}
+
+func TestTextUnmarshaler(t *testing.T) {
+ // fields that implement TextUnmarshaler should be parsed using that interface
+ var args struct {
+ Foo *textUnmarshaler
+ }
+ err := parse("--foo abc", &args)
+ require.NoError(t, err)
+ assert.Equal(t, 3, args.Foo.val)
+}
+
+type boolUnmarshaler bool
+
+func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
+ *p = len(b)%2 == 0
+ return nil
+}
+
+func TestBoolUnmarhsaler(t *testing.T) {
+ // test that a bool type that implements TextUnmarshaler is
+ // handled as a TextUnmarshaler not as a bool
+ var args struct {
+ Foo *boolUnmarshaler
+ }
+ err := parse("--foo ab", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, true, *args.Foo)
+}
+
+type sliceUnmarshaler []int
+
+func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
+ *p = sliceUnmarshaler{len(b)}
+ return nil
+}
+
+func TestSliceUnmarhsaler(t *testing.T) {
+ // test that a slice type that implements TextUnmarshaler is
+ // handled as a TextUnmarshaler not as a slice
+ var args struct {
+ Foo *sliceUnmarshaler
+ Bar string `arg:"positional"`
+ }
+ err := parse("--foo abcde xyz", &args)
+ require.NoError(t, err)
+ require.Len(t, *args.Foo, 1)
+ assert.EqualValues(t, 5, (*args.Foo)[0])
+ assert.Equal(t, "xyz", args.Bar)
+}
+
+func TestIP(t *testing.T) {
+ var args struct {
+ Host net.IP
+ }
+ err := parse("--host 192.168.0.1", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "192.168.0.1", args.Host.String())
+}
+
+func TestPtrToIP(t *testing.T) {
+ var args struct {
+ Host *net.IP
+ }
+ err := parse("--host 192.168.0.1", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "192.168.0.1", args.Host.String())
+}
+
+func TestIPSlice(t *testing.T) {
+ var args struct {
+ Host []net.IP
+ }
+ err := parse("--host 192.168.0.1 127.0.0.1", &args)
+ require.NoError(t, err)
+ require.Len(t, args.Host, 2)
+ assert.Equal(t, "192.168.0.1", args.Host[0].String())
+ assert.Equal(t, "127.0.0.1", args.Host[1].String())
+}
+
+func TestInvalidIPAddress(t *testing.T) {
+ var args struct {
+ Host net.IP
+ }
+ err := parse("--host xxx", &args)
+ assert.Error(t, err)
+}
+
+func TestMAC(t *testing.T) {
+ var args struct {
+ Host net.HardwareAddr
+ }
+ err := parse("--host 0123.4567.89ab", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "01:23:45:67:89:ab", args.Host.String())
+}
+
+func TestInvalidMac(t *testing.T) {
+ var args struct {
+ Host net.HardwareAddr
+ }
+ err := parse("--host xxx", &args)
+ assert.Error(t, err)
+}
+
+func TestMailAddr(t *testing.T) {
+ var args struct {
+ Recipient mail.Address
+ }
+ err := parse("--recipient [email protected]", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "<[email protected]>", args.Recipient.String())
+}
+
+func TestInvalidMailAddr(t *testing.T) {
+ var args struct {
+ Recipient mail.Address
+ }
+ err := parse("--recipient xxx", &args)
+ assert.Error(t, err)
+}
diff --git a/scalar.go b/scalar.go
new file mode 100644
index 0000000..e79b002
--- /dev/null
+++ b/scalar.go
@@ -0,0 +1,145 @@
+package arg
+
+import (
+ "encoding"
+ "fmt"
+ "net"
+ "net/mail"
+ "reflect"
+ "strconv"
+ "time"
+)
+
+// The reflected form of some special types
+var (
+ textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
+ durationType = reflect.TypeOf(time.Duration(0))
+ mailAddressType = reflect.TypeOf(mail.Address{})
+ ipType = reflect.TypeOf(net.IP{})
+ macType = reflect.TypeOf(net.HardwareAddr{})
+)
+
+// isScalar returns true if the type can be parsed from a single string
+func isScalar(t reflect.Type) (scalar, boolean bool) {
+ // If it implements encoding.TextUnmarshaler then use that
+ if t.Implements(textUnmarshalerType) {
+ // scalar=YES, boolean=NO
+ return true, false
+ }
+
+ // If we have a pointer then dereference it
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+
+ // Check for other special types
+ switch t {
+ case durationType, mailAddressType, ipType, macType:
+ // scalar=YES, boolean=NO
+ return true, false
+ }
+
+ // Fall back to checking the kind
+ switch t.Kind() {
+ case reflect.Bool:
+ // scalar=YES, boolean=YES
+ return true, true
+ case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+ reflect.Float32, reflect.Float64:
+ // scalar=YES, boolean=NO
+ return true, false
+ }
+ // scalar=NO, boolean=NO
+ return false, false
+}
+
+// set a value from a string
+func setScalar(v reflect.Value, s string) error {
+ if !v.CanSet() {
+ return fmt.Errorf("field is not exported")
+ }
+
+ // If we have a nil pointer then allocate a new object
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ v.Set(reflect.New(v.Type().Elem()))
+ }
+
+ // Get the object as an interface
+ scalar := v.Interface()
+
+ // If it implements encoding.TextUnmarshaler then use that
+ if scalar, ok := scalar.(encoding.TextUnmarshaler); ok {
+ return scalar.UnmarshalText([]byte(s))
+ }
+
+ // If we have a pointer then dereference it
+ if v.Kind() == reflect.Ptr {
+ v = v.Elem()
+ }
+
+ // Switch on concrete type
+ switch scalar.(type) {
+ case time.Duration:
+ duration, err := time.ParseDuration(s)
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(duration))
+ return nil
+ case mail.Address:
+ addr, err := mail.ParseAddress(s)
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(*addr))
+ return nil
+ case net.IP:
+ ip := net.ParseIP(s)
+ if ip == nil {
+ return fmt.Errorf(`invalid IP address: "%s"`, s)
+ }
+ v.Set(reflect.ValueOf(ip))
+ return nil
+ case net.HardwareAddr:
+ ip, err := net.ParseMAC(s)
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(ip))
+ return nil
+ }
+
+ // Switch on kind so that we can handle derived types
+ switch v.Kind() {
+ case reflect.String:
+ v.SetString(s)
+ case reflect.Bool:
+ x, err := strconv.ParseBool(s)
+ if err != nil {
+ return err
+ }
+ v.SetBool(x)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ x, err := strconv.ParseInt(s, 10, v.Type().Bits())
+ if err != nil {
+ return err
+ }
+ v.SetInt(x)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ x, err := strconv.ParseUint(s, 10, v.Type().Bits())
+ if err != nil {
+ return err
+ }
+ v.SetUint(x)
+ case reflect.Float32, reflect.Float64:
+ x, err := strconv.ParseFloat(s, v.Type().Bits())
+ if err != nil {
+ return err
+ }
+ v.SetFloat(x)
+ default:
+ return fmt.Errorf("cannot parse argument into %s", v.Type().String())
+ }
+ return nil
+}
diff --git a/usage.go b/usage.go
index f9d2eb9..9e9ce77 100644
--- a/usage.go
+++ b/usage.go
@@ -96,7 +96,7 @@ func (p *Parser) WriteHelp(w io.Writer) {
}
// write the list of built in options
- printOption(w, &spec{isBool: true, long: "help", short: "h", help: "display this help and exit"})
+ printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"})
}
func printOption(w io.Writer, spec *spec) {
@@ -125,7 +125,7 @@ func printOption(w io.Writer, spec *spec) {
}
func synopsis(spec *spec, form string) string {
- if spec.isBool {
+ if spec.boolean {
return form
}
return form + " " + strings.ToUpper(spec.long)
diff --git a/usage_test.go b/usage_test.go
index 130cd45..fd2ba3a 100644
--- a/usage_test.go
+++ b/usage_test.go
@@ -10,9 +10,9 @@ import (
)
func TestWriteUsage(t *testing.T) {
- expectedUsage := "usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] INPUT [OUTPUT [OUTPUT ...]]\n"
+ expectedUsage := "usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--workers WORKERS] INPUT [OUTPUT [OUTPUT ...]]\n"
- expectedHelp := `usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] INPUT [OUTPUT [OUTPUT ...]]
+ expectedHelp := `usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--workers WORKERS] INPUT [OUTPUT [OUTPUT ...]]
positional arguments:
input
@@ -26,6 +26,8 @@ options:
--optimize OPTIMIZE, -O OPTIMIZE
optimization level
--ids IDS Ids
+ --workers WORKERS, -w WORKERS
+ number of workers to start
--help, -h display this help and exit
`
var args struct {
@@ -37,6 +39,7 @@ options:
Dataset string `arg:"help:dataset to use"`
Optimize int `arg:"-O,help:optimization level"`
Ids []int64 `arg:"help:Ids"`
+ Workers int `arg:"-w,env:WORKERS,help:number of workers to start"`
}
args.Name = "Foo Bar"
args.Value = 42