summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compflag/compflag.go58
-rw-r--r--compflag/compflag_test.go45
2 files changed, 97 insertions, 6 deletions
diff --git a/compflag/compflag.go b/compflag/compflag.go
index 9128255..fa219c8 100644
--- a/compflag/compflag.go
+++ b/compflag/compflag.go
@@ -39,6 +39,7 @@ import (
"fmt"
"os"
"strconv"
+ "time"
"github.com/posener/complete/v2"
"github.com/posener/complete/v2/predict"
@@ -75,6 +76,12 @@ func (fs *FlagSet) Int(name string, value int, usage string, options ...predict.
return p
}
+func (fs *FlagSet) Duration(name string, value time.Duration, usage string, options ...predict.Option) *time.Duration {
+ p := new(time.Duration)
+ (*flag.FlagSet)(fs).Var(newDurationValue(value, p, predict.Options(options...)), name, usage)
+ return p
+}
+
var CommandLine = (*FlagSet)(flag.CommandLine)
// Parse parses command line arguments. It also performs bash completion when needed.
@@ -95,6 +102,10 @@ func Int(name string, value int, usage string, options ...predict.Option) *int {
return CommandLine.Int(name, value, usage, options...)
}
+func Duration(name string, value time.Duration, usage string, options ...predict.Option) *time.Duration {
+ return CommandLine.Duration(name, value, usage, options...)
+}
+
type boolValue struct {
v *bool
predict.Config
@@ -114,13 +125,13 @@ func (b *boolValue) Set(val string) error {
return b.Check(val)
}
-func (b *boolValue) Get() interface{} { return bool(*b.v) }
+func (b *boolValue) Get() interface{} { return *b.v }
func (b *boolValue) String() string {
if b == nil || b.v == nil {
return strconv.FormatBool(false)
}
- return strconv.FormatBool(bool(*b.v))
+ return strconv.FormatBool(*b.v)
}
func (b *boolValue) IsBoolFlag() bool { return true }
@@ -154,14 +165,14 @@ func (s *stringValue) Set(val string) error {
}
func (s *stringValue) Get() interface{} {
- return string(*s.v)
+ return *s.v
}
func (s *stringValue) String() string {
if s == nil || s.v == nil {
return ""
}
- return string(*s.v)
+ return *s.v
}
func (s *stringValue) Predict(prefix string) []string {
@@ -190,13 +201,13 @@ func (i *intValue) Set(val string) error {
return i.Check(val)
}
-func (i *intValue) Get() interface{} { return int(*i.v) }
+func (i *intValue) Get() interface{} { return *i.v }
func (i *intValue) String() string {
if i == nil || i.v == nil {
return strconv.Itoa(0)
}
- return strconv.Itoa(int(*i.v))
+ return strconv.Itoa(*i.v)
}
func (s *intValue) Predict(prefix string) []string {
@@ -205,3 +216,38 @@ func (s *intValue) Predict(prefix string) []string {
}
return []string{""}
}
+
+type durationValue struct {
+ v *time.Duration
+ predict.Config
+}
+
+func newDurationValue(val time.Duration, p *time.Duration, c predict.Config) *durationValue {
+ *p = val
+ return &durationValue{v: p, Config: c}
+}
+
+func (i *durationValue) Set(val string) error {
+ v, err := time.ParseDuration(val)
+ *i.v = v
+ if err != nil {
+ return fmt.Errorf("bad value for duration flag")
+ }
+ return i.Check(val)
+}
+
+func (i *durationValue) Get() interface{} { return *i.v }
+
+func (i *durationValue) String() string {
+ if i == nil || i.v == nil {
+ return time.Duration(0).String()
+ }
+ return i.v.String()
+}
+
+func (s *durationValue) Predict(prefix string) []string {
+ if s.Predictor != nil {
+ return s.Predictor.Predict(prefix)
+ }
+ return []string{""}
+}
diff --git a/compflag/compflag_test.go b/compflag/compflag_test.go
index 73dc2c1..9d062e2 100644
--- a/compflag/compflag_test.go
+++ b/compflag/compflag_test.go
@@ -3,6 +3,7 @@ package compflag
import (
"flag"
"testing"
+ "time"
"github.com/posener/complete/v2"
"github.com/posener/complete/v2/predict"
@@ -104,3 +105,47 @@ func TestInt(t *testing.T) {
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1"})
})
}
+
+func TestDuration(t *testing.T) {
+ t.Parallel()
+
+ t.Run("options invalid not checked", func(t *testing.T) {
+ var cmd FlagSet
+ value := cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"))
+ err := cmd.Parse([]string{"-a", "1h"})
+ assert.NoError(t, err)
+ assert.Equal(t, time.Hour, *value)
+ })
+
+ t.Run("options valid checked", func(t *testing.T) {
+ var cmd FlagSet
+ value := cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"), predict.OptCheck())
+ err := cmd.Parse([]string{"-a", "1m"})
+ assert.NoError(t, err)
+ assert.Equal(t, time.Minute, *value)
+ })
+
+ t.Run("options invalid checked", func(t *testing.T) {
+ var cmd FlagSet
+ _ = cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"), predict.OptCheck())
+ err := cmd.Parse([]string{"-a", "1h"})
+ assert.Error(t, err)
+ })
+
+ t.Run("options invalid duration value", func(t *testing.T) {
+ var cmd FlagSet
+ _ = cmd.Duration("a", 0, "", predict.OptValues("1h", "1m", "1"), predict.OptCheck())
+ err := cmd.Parse([]string{"-a", "1"})
+ assert.Error(t, err)
+ })
+
+ t.Run("complete", func(t *testing.T) {
+ var cmd FlagSet
+ _ = cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"))
+ complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a ", []string{"1s", "1m"})
+ complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=", []string{"1s", "1m"})
+ complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a 1", []string{"1s", "1m"})
+ complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1s", "1m"})
+ complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1m", []string{"1m"})
+ })
+}