summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--command.go58
-rw-r--r--complete.go73
-rw-r--r--complete_test.go149
-rw-r--r--flag.go12
-rw-r--r--log.go17
5 files changed, 309 insertions, 0 deletions
diff --git a/command.go b/command.go
new file mode 100644
index 0000000..48dece3
--- /dev/null
+++ b/command.go
@@ -0,0 +1,58 @@
+package complete
+
+type Command struct {
+ Sub map[string]Command
+ Flags map[string]FlagOptions
+}
+
+// options returns all available complete options for the given command
+// args are all except the last command line arguments relevant to the command
+func (c *Command) options(args []string) (options []string, only bool) {
+
+ // remove the first argument, which is the command name
+ args = args[1:]
+
+ // if prev has something that needs to follow it,
+ // it is the most relevant completion
+ if options, ok := c.Flags[last(args)]; ok && options.HasFollow {
+ return options.FollowsOptions, true
+ }
+
+ sub, options, only := c.searchSub(args)
+ if only {
+ return
+ }
+
+ // if no subcommand was entered in any of the args, add the
+ // subcommands as complete options.
+ if sub == "" {
+ options = append(options, c.subCommands()...)
+ }
+
+ // add global available complete options
+ for flag := range c.Flags {
+ options = append(options, flag)
+ }
+
+ return
+}
+
+func (c *Command) searchSub(args []string) (sub string, all []string, only bool) {
+ for i, arg := range args {
+ if cmd, ok := c.Sub[arg]; ok {
+ sub = arg
+ all, only = cmd.options(args[i:])
+ return
+ }
+ }
+ return "", nil, false
+}
+
+func (c *Command) subCommands() []string {
+ subs := make([]string, 0, len(c.Sub))
+ for sub := range c.Sub {
+ subs = append(subs, sub)
+ }
+ return subs
+}
+
diff --git a/complete.go b/complete.go
new file mode 100644
index 0000000..eedaa81
--- /dev/null
+++ b/complete.go
@@ -0,0 +1,73 @@
+package complete
+
+import (
+ "fmt"
+ "os"
+ "strings"
+)
+
+const (
+ envComplete = "COMP_LINE"
+ envDebug = "COMP_DEBUG"
+)
+
+type Completer struct {
+ Command
+ log func(format string, args ...interface{})
+}
+
+func New(c Command) *Completer {
+ return &Completer{
+ Command: c,
+ log: logger(),
+ }
+}
+
+func (c *Completer) Complete() {
+ args := getLine()
+ c.log("Completing args: %s", args)
+
+ options := c.complete(args)
+
+ c.log("Completion: %s", options)
+ output(options)
+}
+
+func (c *Completer) complete(args []string) []string {
+ all, _ := c.options(args[:len(args)-1])
+ return c.chooseRelevant(last(args), all)
+}
+
+func (c *Completer) chooseRelevant(last string, list []string) (opts []string) {
+ if last == "" {
+ return list
+ }
+ for _, sub := range list {
+ if strings.HasPrefix(sub, last) {
+ opts = append(opts, sub)
+ }
+ }
+ return
+}
+
+func getLine() []string {
+ line := os.Getenv(envComplete)
+ if line == "" {
+ panic("should be run as a complete script")
+ }
+ return strings.Split(line, " ")
+}
+
+func last(args []string) (last string) {
+ if len(args) > 0 {
+ last = args[len(args)-1]
+ }
+ return
+}
+
+func output(options []string) {
+ // stdout of program defines the complete options
+ for _, option := range options {
+ fmt.Println(option)
+ }
+}
diff --git a/complete_test.go b/complete_test.go
new file mode 100644
index 0000000..55934bd
--- /dev/null
+++ b/complete_test.go
@@ -0,0 +1,149 @@
+package complete
+
+import (
+ "os"
+ "sort"
+ "testing"
+)
+
+func TestCompleter_Complete(t *testing.T) {
+ t.Parallel()
+
+ os.Setenv(envDebug, "1")
+
+ c := Completer{
+ Command: Command{
+ Sub: map[string]Command{
+ "sub1": {
+ Flags: map[string]FlagOptions{
+ "-flag1": FlagUnknownFollow,
+ "-flag2": FlagNoFollow,
+ },
+ },
+ "sub2": {
+ Flags: map[string]FlagOptions{
+ "-flag2": FlagNoFollow,
+ "-flag3": FlagNoFollow,
+ },
+ },
+ },
+ Flags: map[string]FlagOptions{
+ "-h": FlagNoFollow,
+ "-global1": FlagUnknownFollow,
+ },
+ },
+ log: t.Logf,
+ }
+
+ allGlobals := []string{}
+ for sub := range c.Sub {
+ allGlobals = append(allGlobals, sub)
+ }
+ for flag := range c.Flags {
+ allGlobals = append(allGlobals, flag)
+ }
+
+ tests := []struct {
+ args string
+ want []string
+ }{
+ {
+ args: "",
+ want: allGlobals,
+ },
+ {
+ args: "-",
+ want: []string{"-h", "-global1"},
+ },
+ {
+ args: "-h ",
+ want: allGlobals,
+ },
+ {
+ args: "-global1 ", // global1 is known follow flag
+ want: []string{},
+ },
+ {
+ args: "sub",
+ want: []string{"sub1", "sub2"},
+ },
+ {
+ args: "sub1",
+ want: []string{"sub1"},
+ },
+ {
+ args: "sub2",
+ want: []string{"sub2"},
+ },
+ {
+ args: "sub1 ",
+ want: []string{"-flag1", "-flag2", "-h", "-global1"},
+ },
+ {
+ args: "sub2 ",
+ want: []string{"-flag2", "-flag3", "-h", "-global1"},
+ },
+ {
+ args: "sub1 -fl",
+ want: []string{"-flag1", "-flag2"},
+ },
+ {
+ args: "sub1 -flag1",
+ want: []string{"-flag1"},
+ },
+ {
+ args: "sub1 -flag1 ",
+ want: []string{}, // flag1 is unknown follow flag
+ },
+ {
+ args: "sub1 -flag2 ",
+ want: []string{"-flag1", "-flag2", "-h", "-global1"},
+ },
+ {
+ args: "-no-such-flag",
+ want: []string{},
+ },
+ {
+ args: "-no-such-flag ",
+ want: allGlobals,
+ },
+ {
+ args: "no-such-command",
+ want: []string{},
+ },
+ {
+ args: "no-such-command ",
+ want: allGlobals,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.args, func(t *testing.T) {
+
+ tt.args = "cmd " + tt.args
+ os.Setenv(envComplete, tt.args)
+ args := getLine()
+
+ got := c.complete(args)
+
+ sort.Strings(tt.want)
+ sort.Strings(got)
+
+ if !equalSlices(got, tt.want) {
+ t.Errorf("failed '%s'\ngot = %s\nwant: %s", t.Name(), got, tt.want)
+ }
+ })
+ }
+}
+
+func equalSlices(a, b []string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/flag.go b/flag.go
new file mode 100644
index 0000000..2673c7a
--- /dev/null
+++ b/flag.go
@@ -0,0 +1,12 @@
+package complete
+
+type FlagOptions struct {
+ HasFollow bool
+ FollowsOptions []string
+}
+
+var (
+ FlagNoFollow = FlagOptions{}
+ FlagUnknownFollow = FlagOptions{HasFollow: true}
+)
+
diff --git a/log.go b/log.go
new file mode 100644
index 0000000..b3fdb8e
--- /dev/null
+++ b/log.go
@@ -0,0 +1,17 @@
+package complete
+
+import (
+ "io"
+ "io/ioutil"
+ "log"
+ "os"
+)
+
+
+func logger() func(format string, args ...interface{}) {
+ var logfile io.Writer = ioutil.Discard
+ if os.Getenv(envDebug) != "" {
+ logfile = os.Stderr
+ }
+ return log.New(logfile, "complete ", log.Flags()).Printf
+}