diff options
| -rw-r--r-- | compflag/compflag.go | 58 | ||||
| -rw-r--r-- | compflag/compflag_test.go | 45 |
2 files changed, 97 insertions, 6 deletions
diff --git a/compflag/compflag.go b/compflag/compflag.go index 9128255..fa219c8 100644 --- a/compflag/compflag.go +++ b/compflag/compflag.go @@ -39,6 +39,7 @@ import ( "fmt" "os" "strconv" + "time" "github.com/posener/complete/v2" "github.com/posener/complete/v2/predict" @@ -75,6 +76,12 @@ func (fs *FlagSet) Int(name string, value int, usage string, options ...predict. return p } +func (fs *FlagSet) Duration(name string, value time.Duration, usage string, options ...predict.Option) *time.Duration { + p := new(time.Duration) + (*flag.FlagSet)(fs).Var(newDurationValue(value, p, predict.Options(options...)), name, usage) + return p +} + var CommandLine = (*FlagSet)(flag.CommandLine) // Parse parses command line arguments. It also performs bash completion when needed. @@ -95,6 +102,10 @@ func Int(name string, value int, usage string, options ...predict.Option) *int { return CommandLine.Int(name, value, usage, options...) } +func Duration(name string, value time.Duration, usage string, options ...predict.Option) *time.Duration { + return CommandLine.Duration(name, value, usage, options...) +} + type boolValue struct { v *bool predict.Config @@ -114,13 +125,13 @@ func (b *boolValue) Set(val string) error { return b.Check(val) } -func (b *boolValue) Get() interface{} { return bool(*b.v) } +func (b *boolValue) Get() interface{} { return *b.v } func (b *boolValue) String() string { if b == nil || b.v == nil { return strconv.FormatBool(false) } - return strconv.FormatBool(bool(*b.v)) + return strconv.FormatBool(*b.v) } func (b *boolValue) IsBoolFlag() bool { return true } @@ -154,14 +165,14 @@ func (s *stringValue) Set(val string) error { } func (s *stringValue) Get() interface{} { - return string(*s.v) + return *s.v } func (s *stringValue) String() string { if s == nil || s.v == nil { return "" } - return string(*s.v) + return *s.v } func (s *stringValue) Predict(prefix string) []string { @@ -190,13 +201,13 @@ func (i *intValue) Set(val string) error { return i.Check(val) } -func (i *intValue) Get() interface{} { return int(*i.v) } +func (i *intValue) Get() interface{} { return *i.v } func (i *intValue) String() string { if i == nil || i.v == nil { return strconv.Itoa(0) } - return strconv.Itoa(int(*i.v)) + return strconv.Itoa(*i.v) } func (s *intValue) Predict(prefix string) []string { @@ -205,3 +216,38 @@ func (s *intValue) Predict(prefix string) []string { } return []string{""} } + +type durationValue struct { + v *time.Duration + predict.Config +} + +func newDurationValue(val time.Duration, p *time.Duration, c predict.Config) *durationValue { + *p = val + return &durationValue{v: p, Config: c} +} + +func (i *durationValue) Set(val string) error { + v, err := time.ParseDuration(val) + *i.v = v + if err != nil { + return fmt.Errorf("bad value for duration flag") + } + return i.Check(val) +} + +func (i *durationValue) Get() interface{} { return *i.v } + +func (i *durationValue) String() string { + if i == nil || i.v == nil { + return time.Duration(0).String() + } + return i.v.String() +} + +func (s *durationValue) Predict(prefix string) []string { + if s.Predictor != nil { + return s.Predictor.Predict(prefix) + } + return []string{""} +} diff --git a/compflag/compflag_test.go b/compflag/compflag_test.go index 73dc2c1..9d062e2 100644 --- a/compflag/compflag_test.go +++ b/compflag/compflag_test.go @@ -3,6 +3,7 @@ package compflag import ( "flag" "testing" + "time" "github.com/posener/complete/v2" "github.com/posener/complete/v2/predict" @@ -104,3 +105,47 @@ func TestInt(t *testing.T) { complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1"}) }) } + +func TestDuration(t *testing.T) { + t.Parallel() + + t.Run("options invalid not checked", func(t *testing.T) { + var cmd FlagSet + value := cmd.Duration("a", 0, "", predict.OptValues("1s", "1m")) + err := cmd.Parse([]string{"-a", "1h"}) + assert.NoError(t, err) + assert.Equal(t, time.Hour, *value) + }) + + t.Run("options valid checked", func(t *testing.T) { + var cmd FlagSet + value := cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"), predict.OptCheck()) + err := cmd.Parse([]string{"-a", "1m"}) + assert.NoError(t, err) + assert.Equal(t, time.Minute, *value) + }) + + t.Run("options invalid checked", func(t *testing.T) { + var cmd FlagSet + _ = cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"), predict.OptCheck()) + err := cmd.Parse([]string{"-a", "1h"}) + assert.Error(t, err) + }) + + t.Run("options invalid duration value", func(t *testing.T) { + var cmd FlagSet + _ = cmd.Duration("a", 0, "", predict.OptValues("1h", "1m", "1"), predict.OptCheck()) + err := cmd.Parse([]string{"-a", "1"}) + assert.Error(t, err) + }) + + t.Run("complete", func(t *testing.T) { + var cmd FlagSet + _ = cmd.Duration("a", 0, "", predict.OptValues("1s", "1m")) + complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a ", []string{"1s", "1m"}) + complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=", []string{"1s", "1m"}) + complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a 1", []string{"1s", "1m"}) + complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1s", "1m"}) + complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1m", []string{"1m"}) + }) +} |
