summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2022-10-10 08:41:57 -0700
committerGitHub <[email protected]>2022-10-10 08:41:57 -0700
commitdbc2ba5d0c9a6a439d1f825f8c299fb276bbc911 (patch)
tree3c922731582aa6a69241607838bcb52d783457ce
parent11f9b624a9e5ee803b3810d2329fb027662c69f4 (diff)
parent4fc9666f79d7a9a84be9963e34b5f2479cd340b6 (diff)
Merge pull request #198 from daenney/mustparse
Implement MustParse on Parser
-rw-r--r--parse.go27
-rw-r--r--parse_test.go48
2 files changed, 63 insertions, 12 deletions
diff --git a/parse.go b/parse.go
index c8cd79e..dc87947 100644
--- a/parse.go
+++ b/parse.go
@@ -82,18 +82,7 @@ func MustParse(dest ...interface{}) *Parser {
return nil // just in case osExit was monkey-patched
}
- err = p.Parse(flags())
- switch {
- case err == ErrHelp:
- p.writeHelpForSubcommand(stdout, p.lastCmd)
- osExit(0)
- case err == ErrVersion:
- fmt.Fprintln(stdout, p.version)
- osExit(0)
- case err != nil:
- p.failWithSubcommand(err.Error(), p.lastCmd)
- }
-
+ p.MustParse(flags())
return p
}
@@ -449,6 +438,20 @@ func (p *Parser) Parse(args []string) error {
return err
}
+func (p *Parser) MustParse(args []string) {
+ err := p.Parse(args)
+ switch {
+ case err == ErrHelp:
+ p.writeHelpForSubcommand(stdout, p.lastCmd)
+ osExit(0)
+ case err == ErrVersion:
+ fmt.Fprintln(stdout, p.version)
+ osExit(0)
+ case err != nil:
+ p.failWithSubcommand(err.Error(), p.lastCmd)
+ }
+}
+
// process environment vars for the given arguments
func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error {
for _, spec := range specs {
diff --git a/parse_test.go b/parse_test.go
index 4ea6bc4..7e84def 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -860,6 +860,54 @@ func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) {
assert.Equal(t, "", args.Sub.Foo)
}
+func TestParserMustParseEmptyArgs(t *testing.T) {
+ // this mirrors TestEmptyArgs
+ p, err := NewParser(Config{}, &struct{}{})
+ require.NoError(t, err)
+ assert.NotNil(t, p)
+ p.MustParse(nil)
+}
+
+func TestParserMustParse(t *testing.T) {
+ tests := []struct {
+ name string
+ args versioned
+ cmdLine []string
+ code int
+ output string
+ }{
+ {name: "help", args: struct{}{}, cmdLine: []string{"--help"}, code: 0, output: "display this help and exit"},
+ {name: "version", args: versioned{}, cmdLine: []string{"--version"}, code: 0, output: "example 3.2.1"},
+ {name: "invalid", args: struct{}{}, cmdLine: []string{"invalid"}, code: -1, output: ""},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ originalExit := osExit
+ originalStdout := stdout
+ defer func() {
+ osExit = originalExit
+ stdout = originalStdout
+ }()
+
+ var exitCode *int
+ osExit = func(code int) { exitCode = &code }
+ var b bytes.Buffer
+ stdout = &b
+
+ p, err := NewParser(Config{}, &tt.args)
+ require.NoError(t, err)
+ assert.NotNil(t, p)
+
+ p.MustParse(tt.cmdLine)
+ assert.NotNil(t, exitCode)
+ assert.Equal(t, tt.code, *exitCode)
+ assert.Contains(t, b.String(), tt.output)
+ })
+ }
+}
+
type textUnmarshaler struct {
val int
}