summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Flint <[email protected]>2023-02-08 06:56:56 -0800
committerGitHub <[email protected]>2023-02-08 06:56:56 -0800
commite25b4707a7d6c63ff6910c1e1bcb416cb8debfb2 (patch)
tree1077d83c624e955733ce4e695465f084f0be9539
parent5dbdd5d0c585feef8723223464a9362d635324f6 (diff)
parentdf28e7154bbab76436bc59e5dc67fb6d6824fc62 (diff)
Merge pull request #211 from alexflint/clean-up-osexit-stderr-stdout
clean up customizable stdout, stderr, and exit in parser config
-rw-r--r--example_test.go31
-rw-r--r--parse.go65
-rw-r--r--parse_test.go69
-rw-r--r--usage.go14
-rw-r--r--usage_test.go34
5 files changed, 80 insertions, 133 deletions
diff --git a/example_test.go b/example_test.go
index fd64777..5272393 100644
--- a/example_test.go
+++ b/example_test.go
@@ -162,8 +162,7 @@ func Example_helpText() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stdout = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
@@ -195,8 +194,7 @@ func Example_helpPlaceholder() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stdout = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
@@ -236,8 +234,7 @@ func Example_helpTextWithSubcommand() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stdout = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
@@ -274,8 +271,7 @@ func Example_helpTextWhenUsingSubcommand() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stdout = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
@@ -311,10 +307,9 @@ func Example_writeHelpForSubcommand() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stdout = os.Stdout
+ exit := func(int) {}
- p, err := NewParser(Config{}, &args)
+ p, err := NewParser(Config{Exit: exit}, &args)
if err != nil {
fmt.Println(err)
os.Exit(1)
@@ -360,10 +355,9 @@ func Example_writeHelpForSubcommandNested() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stdout = os.Stdout
+ exit := func(int) {}
- p, err := NewParser(Config{}, &args)
+ p, err := NewParser(Config{Exit: exit}, &args)
if err != nil {
fmt.Println(err)
os.Exit(1)
@@ -397,8 +391,7 @@ func Example_errorText() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stderr = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
@@ -421,8 +414,7 @@ func Example_errorTextForSubcommand() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stderr = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
@@ -457,8 +449,7 @@ func Example_subcommand() {
}
// This is only necessary when running inside golang's runnable example harness
- osExit = func(int) {}
- stderr = os.Stdout
+ mustParseExit = func(int) {}
MustParse(&args)
diff --git a/parse.go b/parse.go
index 6d8b509..be77924 100644
--- a/parse.go
+++ b/parse.go
@@ -75,13 +75,28 @@ var ErrHelp = errors.New("help requested by user")
// ErrVersion indicates that --version was provided
var ErrVersion = errors.New("version requested by user")
+// for monkey patching in example code
+var mustParseExit = os.Exit
+
// MustParse processes command line arguments and exits upon failure
func MustParse(dest ...interface{}) *Parser {
- p, err := NewParser(Config{}, dest...)
+ return mustParse(Config{Exit: mustParseExit}, dest...)
+}
+
+// mustParse is a helper that facilitates testing
+func mustParse(config Config, dest ...interface{}) *Parser {
+ if config.Exit == nil {
+ config.Exit = os.Exit
+ }
+ if config.Out == nil {
+ config.Out = os.Stdout
+ }
+
+ p, err := NewParser(config, dest...)
if err != nil {
- fmt.Fprintln(stdout, err)
- osExit(-1)
- return nil // just in case osExit was monkey-patched
+ fmt.Fprintln(config.Out, err)
+ config.Exit(-1)
+ return nil
}
p.MustParse(flags())
@@ -121,9 +136,11 @@ type Config struct {
// subcommand
StrictSubcommands bool
- OsExit func(int)
- Stdout io.Writer
- Stderr io.Writer
+ // Exit is called to terminate the process with an error code (defaults to os.Exit)
+ Exit func(int)
+
+ // Out is where help text, usage text, and failure messages are printed (defaults to os.Stdout)
+ Out io.Writer
}
// Parser represents a set of command line options with destination values
@@ -137,10 +154,6 @@ type Parser struct {
// the following field changes during processing of command line arguments
lastCmd *command
-
- osExit func(int)
- stdout io.Writer
- stderr io.Writer
}
// Versioned is the interface that the destination struct should implement to
@@ -190,6 +203,14 @@ func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner
// NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) {
+ // fill in defaults
+ if config.Exit == nil {
+ config.Exit = os.Exit
+ }
+ if config.Out == nil {
+ config.Out = os.Stdout
+ }
+
// first pick a name for the command for use in the usage text
var name string
switch {
@@ -205,20 +226,6 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
p := Parser{
cmd: &command{name: name},
config: config,
-
- osExit: osExit,
- stdout: stdout,
- stderr: stderr,
- }
-
- if config.OsExit != nil {
- p.osExit = config.OsExit
- }
- if config.Stdout != nil {
- p.stdout = config.Stdout
- }
- if config.Stderr != nil {
- p.stderr = config.Stderr
}
// make a list of roots
@@ -506,11 +513,11 @@ func (p *Parser) MustParse(args []string) {
err := p.Parse(args)
switch {
case err == ErrHelp:
- p.writeHelpForSubcommand(p.stdout, p.lastCmd)
- p.osExit(0)
+ p.writeHelpForSubcommand(p.config.Out, p.lastCmd)
+ p.config.Exit(0)
case err == ErrVersion:
- fmt.Fprintln(p.stdout, p.version)
- p.osExit(0)
+ fmt.Fprintln(p.config.Out, p.version)
+ p.config.Exit(0)
case err != nil:
p.failWithSubcommand(err.Error(), p.lastCmd)
}
diff --git a/parse_test.go b/parse_test.go
index 64119a8..d368b17 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -885,26 +885,18 @@ func TestParserMustParse(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
- originalExit := osExit
- originalStdout := stdout
- defer func() {
- osExit = originalExit
- stdout = originalStdout
- }()
+ var exitCode int
+ var stdout bytes.Buffer
+ exit := func(code int) { exitCode = code }
- var exitCode *int
- osExit = func(code int) { exitCode = &code }
- var b bytes.Buffer
- stdout = &b
-
- p, err := NewParser(Config{}, &tt.args)
+ p, err := NewParser(Config{Exit: exit, Out: &stdout}, &tt.args)
require.NoError(t, err)
assert.NotNil(t, p)
p.MustParse(tt.cmdLine)
assert.NotNil(t, exitCode)
- assert.Equal(t, tt.code, *exitCode)
- assert.Contains(t, b.String(), tt.output)
+ assert.Equal(t, tt.code, exitCode)
+ assert.Contains(t, stdout.String(), tt.output)
})
}
}
@@ -1484,70 +1476,53 @@ func TestUnexportedFieldsSkipped(t *testing.T) {
}
func TestMustParseInvalidParser(t *testing.T) {
- originalExit := osExit
- originalStdout := stdout
- defer func() {
- osExit = originalExit
- stdout = originalStdout
- }()
-
var exitCode int
- osExit = func(code int) { exitCode = code }
- stdout = &bytes.Buffer{}
+ var stdout bytes.Buffer
+ exit := func(code int) { exitCode = code }
var args struct {
CannotParse struct{}
}
- parser := MustParse(&args)
+ parser := mustParse(Config{Out: &stdout, Exit: exit}, &args)
assert.Nil(t, parser)
assert.Equal(t, -1, exitCode)
}
func TestMustParsePrintsHelp(t *testing.T) {
- originalExit := osExit
- originalStdout := stdout
originalArgs := os.Args
defer func() {
- osExit = originalExit
- stdout = originalStdout
os.Args = originalArgs
}()
- var exitCode *int
- osExit = func(code int) { exitCode = &code }
os.Args = []string{"someprogram", "--help"}
- stdout = &bytes.Buffer{}
+
+ var exitCode int
+ var stdout bytes.Buffer
+ exit := func(code int) { exitCode = code }
var args struct{}
- parser := MustParse(&args)
+ parser := mustParse(Config{Out: &stdout, Exit: exit}, &args)
assert.NotNil(t, parser)
- require.NotNil(t, exitCode)
- assert.Equal(t, 0, *exitCode)
+ assert.Equal(t, 0, exitCode)
}
func TestMustParsePrintsVersion(t *testing.T) {
- originalExit := osExit
- originalStdout := stdout
originalArgs := os.Args
defer func() {
- osExit = originalExit
- stdout = originalStdout
os.Args = originalArgs
}()
- var exitCode *int
- osExit = func(code int) { exitCode = &code }
- os.Args = []string{"someprogram", "--version"}
+ var exitCode int
+ var stdout bytes.Buffer
+ exit := func(code int) { exitCode = code }
- var b bytes.Buffer
- stdout = &b
+ os.Args = []string{"someprogram", "--version"}
var args versioned
- parser := MustParse(&args)
+ parser := mustParse(Config{Out: &stdout, Exit: exit}, &args)
require.NotNil(t, parser)
- require.NotNil(t, exitCode)
- assert.Equal(t, 0, *exitCode)
- assert.Equal(t, "example 3.2.1\n", b.String())
+ assert.Equal(t, 0, exitCode)
+ assert.Equal(t, "example 3.2.1\n", stdout.String())
}
type mapWithUnmarshalText struct {
diff --git a/usage.go b/usage.go
index 80eba45..43d6231 100644
--- a/usage.go
+++ b/usage.go
@@ -3,20 +3,12 @@ package arg
import (
"fmt"
"io"
- "os"
"strings"
)
// the width of the left column
const colWidth = 25
-// to allow monkey patching in tests
-var (
- stdout io.Writer = os.Stdout
- stderr io.Writer = os.Stderr
- osExit = os.Exit
-)
-
// Fail prints usage information to stderr and exits with non-zero status
func (p *Parser) Fail(msg string) {
p.failWithSubcommand(msg, p.cmd)
@@ -39,9 +31,9 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error {
// failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status
func (p *Parser) failWithSubcommand(msg string, cmd *command) {
- p.writeUsageForSubcommand(p.stderr, cmd)
- fmt.Fprintln(p.stderr, "error:", msg)
- p.osExit(-1)
+ p.writeUsageForSubcommand(p.config.Out, cmd)
+ fmt.Fprintln(p.config.Out, "error:", msg)
+ p.config.Exit(-1)
}
// WriteUsage writes usage information to the given writer
diff --git a/usage_test.go b/usage_test.go
index be5894a..69feac2 100644
--- a/usage_test.go
+++ b/usage_test.go
@@ -572,18 +572,9 @@ Options:
}
func TestFail(t *testing.T) {
- originalStderr := stderr
- originalExit := osExit
- defer func() {
- stderr = originalStderr
- osExit = originalExit
- }()
-
- var b bytes.Buffer
- stderr = &b
-
+ var stdout bytes.Buffer
var exitCode int
- osExit = func(code int) { exitCode = code }
+ exit := func(code int) { exitCode = code }
expectedStdout := `
Usage: example [--foo FOO]
@@ -593,27 +584,18 @@ error: something went wrong
var args struct {
Foo int
}
- p, err := NewParser(Config{Program: "example"}, &args)
+ p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args)
require.NoError(t, err)
p.Fail("something went wrong")
- assert.Equal(t, expectedStdout[1:], b.String())
+ assert.Equal(t, expectedStdout[1:], stdout.String())
assert.Equal(t, -1, exitCode)
}
func TestFailSubcommand(t *testing.T) {
- originalStderr := stderr
- originalExit := osExit
- defer func() {
- stderr = originalStderr
- osExit = originalExit
- }()
-
- var b bytes.Buffer
- stderr = &b
-
+ var stdout bytes.Buffer
var exitCode int
- osExit = func(code int) { exitCode = code }
+ exit := func(code int) { exitCode = code }
expectedStdout := `
Usage: example sub
@@ -623,13 +605,13 @@ error: something went wrong
var args struct {
Sub *struct{} `arg:"subcommand"`
}
- p, err := NewParser(Config{Program: "example"}, &args)
+ p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args)
require.NoError(t, err)
err = p.FailSubcommand("something went wrong", "sub")
require.NoError(t, err)
- assert.Equal(t, expectedStdout[1:], b.String())
+ assert.Equal(t, expectedStdout[1:], stdout.String())
assert.Equal(t, -1, exitCode)
}