summaryrefslogtreecommitdiff
path: root/parse_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'parse_test.go')
-rw-r--r--parse_test.go261
1 files changed, 259 insertions, 2 deletions
diff --git a/parse_test.go b/parse_test.go
index 1189588..964c9a7 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -1,9 +1,12 @@
package arg
import (
+ "net"
+ "net/mail"
"os"
"strings"
"testing"
+ "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -14,10 +17,14 @@ func parse(cmdline string, dest interface{}) error {
if err != nil {
return err
}
- return p.Parse(strings.Split(cmdline, " "))
+ var parts []string
+ if len(cmdline) > 0 {
+ parts = strings.Split(cmdline, " ")
+ }
+ return p.Parse(parts)
}
-func TestStringSingle(t *testing.T) {
+func TestString(t *testing.T) {
var args struct {
Foo string
}
@@ -26,6 +33,69 @@ func TestStringSingle(t *testing.T) {
assert.Equal(t, "bar", args.Foo)
}
+func TestInt(t *testing.T) {
+ var args struct {
+ Foo int
+ }
+ err := parse("--foo 7", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, 7, args.Foo)
+}
+
+func TestUint(t *testing.T) {
+ var args struct {
+ Foo uint
+ }
+ err := parse("--foo 7", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, 7, args.Foo)
+}
+
+func TestFloat(t *testing.T) {
+ var args struct {
+ Foo float32
+ }
+ err := parse("--foo 3.4", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, 3.4, args.Foo)
+}
+
+func TestDuration(t *testing.T) {
+ var args struct {
+ Foo time.Duration
+ }
+ err := parse("--foo 3ms", &args)
+ require.NoError(t, err)
+ assert.Equal(t, 3*time.Millisecond, args.Foo)
+}
+
+func TestInvalidDuration(t *testing.T) {
+ var args struct {
+ Foo time.Duration
+ }
+ err := parse("--foo xxx", &args)
+ require.Error(t, err)
+}
+
+func TestIntPtr(t *testing.T) {
+ var args struct {
+ Foo *int
+ }
+ err := parse("--foo 123", &args)
+ require.NoError(t, err)
+ require.NotNil(t, args.Foo)
+ assert.Equal(t, 123, *args.Foo)
+}
+
+func TestIntPtrNotPresent(t *testing.T) {
+ var args struct {
+ Foo *int
+ }
+ err := parse("", &args)
+ require.NoError(t, err)
+ assert.Nil(t, args.Foo)
+}
+
func TestMixed(t *testing.T) {
var args struct {
Foo string `arg:"-f"`
@@ -317,6 +387,14 @@ func TestUnsupportedSliceElement(t *testing.T) {
var args struct {
Foo []interface{}
}
+ err := parse("--foo 3", &args)
+ assert.Error(t, err)
+}
+
+func TestUnsupportedSliceElementMissingValue(t *testing.T) {
+ var args struct {
+ Foo []interface{}
+ }
err := parse("--foo", &args)
assert.Error(t, err)
}
@@ -357,3 +435,182 @@ func TestMustParse(t *testing.T) {
assert.Equal(t, "bar", args.Foo)
assert.NotNil(t, parser)
}
+
+func TestEnvironmentVariable(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "bar", args.Foo)
+}
+
+func TestEnvironmentVariableOverrideName(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env:BAZ"`
+ }
+ os.Setenv("BAZ", "bar")
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "bar", args.Foo)
+}
+
+func TestEnvironmentVariableOverrideArgument(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example", "--foo", "baz"}
+ MustParse(&args)
+ assert.Equal(t, "baz", args.Foo)
+}
+
+func TestEnvironmentVariableError(t *testing.T) {
+ var args struct {
+ Foo int `arg:"env"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example"}
+ err := Parse(&args)
+ assert.Error(t, err)
+}
+
+func TestEnvironmentVariableRequired(t *testing.T) {
+ var args struct {
+ Foo string `arg:"env,required"`
+ }
+ os.Setenv("FOO", "bar")
+ os.Args = []string{"example"}
+ MustParse(&args)
+ assert.Equal(t, "bar", args.Foo)
+}
+
+type textUnmarshaler struct {
+ val int
+}
+
+func (f *textUnmarshaler) UnmarshalText(b []byte) error {
+ f.val = len(b)
+ return nil
+}
+
+func TestTextUnmarshaler(t *testing.T) {
+ // fields that implement TextUnmarshaler should be parsed using that interface
+ var args struct {
+ Foo *textUnmarshaler
+ }
+ err := parse("--foo abc", &args)
+ require.NoError(t, err)
+ assert.Equal(t, 3, args.Foo.val)
+}
+
+type boolUnmarshaler bool
+
+func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
+ *p = len(b)%2 == 0
+ return nil
+}
+
+func TestBoolUnmarhsaler(t *testing.T) {
+ // test that a bool type that implements TextUnmarshaler is
+ // handled as a TextUnmarshaler not as a bool
+ var args struct {
+ Foo *boolUnmarshaler
+ }
+ err := parse("--foo ab", &args)
+ require.NoError(t, err)
+ assert.EqualValues(t, true, *args.Foo)
+}
+
+type sliceUnmarshaler []int
+
+func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
+ *p = sliceUnmarshaler{len(b)}
+ return nil
+}
+
+func TestSliceUnmarhsaler(t *testing.T) {
+ // test that a slice type that implements TextUnmarshaler is
+ // handled as a TextUnmarshaler not as a slice
+ var args struct {
+ Foo *sliceUnmarshaler
+ Bar string `arg:"positional"`
+ }
+ err := parse("--foo abcde xyz", &args)
+ require.NoError(t, err)
+ require.Len(t, *args.Foo, 1)
+ assert.EqualValues(t, 5, (*args.Foo)[0])
+ assert.Equal(t, "xyz", args.Bar)
+}
+
+func TestIP(t *testing.T) {
+ var args struct {
+ Host net.IP
+ }
+ err := parse("--host 192.168.0.1", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "192.168.0.1", args.Host.String())
+}
+
+func TestPtrToIP(t *testing.T) {
+ var args struct {
+ Host *net.IP
+ }
+ err := parse("--host 192.168.0.1", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "192.168.0.1", args.Host.String())
+}
+
+func TestIPSlice(t *testing.T) {
+ var args struct {
+ Host []net.IP
+ }
+ err := parse("--host 192.168.0.1 127.0.0.1", &args)
+ require.NoError(t, err)
+ require.Len(t, args.Host, 2)
+ assert.Equal(t, "192.168.0.1", args.Host[0].String())
+ assert.Equal(t, "127.0.0.1", args.Host[1].String())
+}
+
+func TestInvalidIPAddress(t *testing.T) {
+ var args struct {
+ Host net.IP
+ }
+ err := parse("--host xxx", &args)
+ assert.Error(t, err)
+}
+
+func TestMAC(t *testing.T) {
+ var args struct {
+ Host net.HardwareAddr
+ }
+ err := parse("--host 0123.4567.89ab", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "01:23:45:67:89:ab", args.Host.String())
+}
+
+func TestInvalidMac(t *testing.T) {
+ var args struct {
+ Host net.HardwareAddr
+ }
+ err := parse("--host xxx", &args)
+ assert.Error(t, err)
+}
+
+func TestMailAddr(t *testing.T) {
+ var args struct {
+ Recipient mail.Address
+ }
+ err := parse("--recipient [email protected]", &args)
+ require.NoError(t, err)
+ assert.Equal(t, "<[email protected]>", args.Recipient.String())
+}
+
+func TestInvalidMailAddr(t *testing.T) {
+ var args struct {
+ Recipient mail.Address
+ }
+ err := parse("--recipient xxx", &args)
+ assert.Error(t, err)
+}