summaryrefslogtreecommitdiff
path: root/sequence_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'sequence_test.go')
-rw-r--r--sequence_test.go152
1 files changed, 152 insertions, 0 deletions
diff --git a/sequence_test.go b/sequence_test.go
new file mode 100644
index 0000000..fde3e3a
--- /dev/null
+++ b/sequence_test.go
@@ -0,0 +1,152 @@
+package arg
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSetSliceWithoutClearing(t *testing.T) {
+ xs := []int{10}
+ entries := []string{"1", "2", "3"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false)
+ require.NoError(t, err)
+ assert.Equal(t, []int{10, 1, 2, 3}, xs)
+}
+
+func TestSetSliceAfterClearing(t *testing.T) {
+ xs := []int{100}
+ entries := []string{"1", "2", "3"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ require.NoError(t, err)
+ assert.Equal(t, []int{1, 2, 3}, xs)
+}
+
+func TestSetSliceInvalid(t *testing.T) {
+ xs := []int{100}
+ entries := []string{"invalid"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetSlicePtr(t *testing.T) {
+ var xs []*int
+ entries := []string{"1", "2", "3"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, xs, 3)
+ assert.Equal(t, 1, *xs[0])
+ assert.Equal(t, 2, *xs[1])
+ assert.Equal(t, 3, *xs[2])
+}
+
+func TestSetSliceTextUnmarshaller(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var xs []*textUnmarshaler
+ entries := []string{"a", "aa", "aaa"}
+ err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, xs, 3)
+ assert.Equal(t, 1, xs[0].val)
+ assert.Equal(t, 2, xs[1].val)
+ assert.Equal(t, 3, xs[2].val)
+}
+
+func TestSetMapWithoutClearing(t *testing.T) {
+ m := map[string]int{"foo": 10}
+ entries := []string{"a=1", "b=2"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, false)
+ require.NoError(t, err)
+ require.Len(t, m, 3)
+ assert.Equal(t, 1, m["a"])
+ assert.Equal(t, 2, m["b"])
+ assert.Equal(t, 10, m["foo"])
+}
+
+func TestSetMapAfterClearing(t *testing.T) {
+ m := map[string]int{"foo": 10}
+ entries := []string{"a=1", "b=2"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 2)
+ assert.Equal(t, 1, m["a"])
+ assert.Equal(t, 2, m["b"])
+}
+
+func TestSetMapWithKeyPointer(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[*string]int
+ entries := []string{"abc=123"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 1)
+}
+
+func TestSetMapWithValuePointer(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[string]*int
+ entries := []string{"abc=123"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 1)
+ assert.Equal(t, 123, *m["abc"])
+}
+
+func TestSetMapTextUnmarshaller(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[textUnmarshaler]*textUnmarshaler
+ entries := []string{"a=123", "aa=12", "aaa=1"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ require.NoError(t, err)
+ require.Len(t, m, 3)
+ assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}])
+ assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}])
+ assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}])
+}
+
+func TestSetMapInvalidKey(t *testing.T) {
+ var m map[int]int
+ entries := []string{"invalid=123"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetMapInvalidValue(t *testing.T) {
+ var m map[int]int
+ entries := []string{"123=invalid"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetMapMalformed(t *testing.T) {
+ // textUnmarshaler is a struct that captures the length of the string passed to it
+ var m map[string]string
+ entries := []string{"missing_equals_sign"}
+ err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
+ assert.Error(t, err)
+}
+
+func TestSetSliceOrMapErrors(t *testing.T) {
+ var err error
+ var dest reflect.Value
+
+ // converting a slice to a reflect.Value in this way will make it read only
+ var cannotSet []int
+ dest = reflect.ValueOf(cannotSet)
+ err = setSliceOrMap(dest, nil, false)
+ assert.Error(t, err)
+
+ // check what happens when we pass in something that is not a slice or a map
+ var notSliceOrMap string
+ dest = reflect.ValueOf(&notSliceOrMap).Elem()
+ err = setSliceOrMap(dest, nil, false)
+ assert.Error(t, err)
+
+ // check what happens when we pass in a pointer to something that is not a slice or a map
+ var stringPtr *string
+ dest = reflect.ValueOf(&stringPtr).Elem()
+ err = setSliceOrMap(dest, nil, false)
+ assert.Error(t, err)
+}