summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--parse.go114
-rw-r--r--parse_test.go15
-rw-r--r--usage.go6
3 files changed, 79 insertions, 56 deletions
diff --git a/parse.go b/parse.go
index c4afda2..e1d1b29 100644
--- a/parse.go
+++ b/parse.go
@@ -24,7 +24,6 @@ type spec struct {
separate bool
help string
env string
- wasPresent bool
boolean bool
}
@@ -80,7 +79,7 @@ type Config struct {
// Parser represents a set of command line options with destination values
type Parser struct {
- spec []*spec
+ specs []*spec
config Config
version string
description string
@@ -214,7 +213,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
}
}
}
- p.spec = append(p.spec, &spec)
+ p.specs = append(p.specs, &spec)
// if this was an embedded field then we already returned true up above
return false
@@ -250,21 +249,18 @@ func (p *Parser) Parse(args []string) error {
}
// Process all command line arguments
- err := process(p.spec, args)
- if err != nil {
- return err
- }
-
- // Validate
- return validate(p.spec)
+ return p.process(args)
}
// process goes through arguments one-by-one, parses them, and assigns the result to
// the underlying struct field
-func process(specs []*spec, args []string) error {
+func (p *Parser) process(args []string) error {
+ // track the options we have seen
+ wasPresent := make(map[*spec]bool)
+
// construct a map from --option to spec
optionMap := make(map[string]*spec)
- for _, spec := range specs {
+ for _, spec := range p.specs {
if spec.positional {
continue
}
@@ -274,34 +270,43 @@ func process(specs []*spec, args []string) error {
if spec.short != "" {
optionMap[spec.short] = spec
}
- if spec.env != "" {
- if value, found := os.LookupEnv(spec.env); found {
- if spec.multiple {
- // expect a CSV string in an environment
- // variable in the case of multiple values
- values, err := csv.NewReader(strings.NewReader(value)).Read()
- if err != nil {
- return fmt.Errorf(
- "error reading a CSV string from environment variable %s with multiple values: %v",
- spec.env,
- err,
- )
- }
- if err = setSlice(spec.dest, values, !spec.separate); err != nil {
- return fmt.Errorf(
- "error processing environment variable %s with multiple values: %v",
- spec.env,
- err,
- )
- }
- } else {
- if err := scalar.ParseValue(spec.dest, value); err != nil {
- return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
- }
- }
- spec.wasPresent = true
+ }
+
+ // deal with environment vars
+ for _, spec := range p.specs {
+ if spec.env == "" {
+ continue
+ }
+
+ value, found := os.LookupEnv(spec.env)
+ if !found {
+ continue
+ }
+
+ if spec.multiple {
+ // expect a CSV string in an environment
+ // variable in the case of multiple values
+ values, err := csv.NewReader(strings.NewReader(value)).Read()
+ if err != nil {
+ return fmt.Errorf(
+ "error reading a CSV string from environment variable %s with multiple values: %v",
+ spec.env,
+ err,
+ )
+ }
+ if err = setSlice(spec.dest, values, !spec.separate); err != nil {
+ return fmt.Errorf(
+ "error processing environment variable %s with multiple values: %v",
+ spec.env,
+ err,
+ )
+ }
+ } else {
+ if err := scalar.ParseValue(spec.dest, value); err != nil {
+ return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
}
}
+ wasPresent[spec] = true
}
// process each string from the command line
@@ -334,7 +339,7 @@ func process(specs []*spec, args []string) error {
if !ok {
return fmt.Errorf("unknown argument %s", arg)
}
- spec.wasPresent = true
+ wasPresent[spec] = true
// deal with the case of multiple values
if spec.multiple {
@@ -382,20 +387,21 @@ func process(specs []*spec, args []string) error {
}
// process positionals
- for _, spec := range specs {
+ for _, spec := range p.specs {
if !spec.positional {
continue
}
- if spec.required && len(positionals) == 0 {
- return fmt.Errorf("%s is required", spec.long)
+ if len(positionals) == 0 {
+ break
}
+ wasPresent[spec] = true
if spec.multiple {
err := setSlice(spec.dest, positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
positionals = nil
- } else if len(positionals) > 0 {
+ } else {
err := scalar.ParseValue(spec.dest, positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
@@ -406,6 +412,18 @@ func process(specs []*spec, args []string) error {
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
+
+ // finally check that all the required args were provided
+ for _, spec := range p.specs {
+ if spec.required && !wasPresent[spec] {
+ name := spec.long
+ if !spec.positional {
+ name = "--" + spec.long
+ }
+ return fmt.Errorf("%s is required", name)
+ }
+ }
+
return nil
}
@@ -427,16 +445,6 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
}
-// validate an argument spec after arguments have been parse
-func validate(spec []*spec) error {
- for _, arg := range spec {
- if !arg.positional && arg.required && !arg.wasPresent {
- return fmt.Errorf("--%s is required", arg.long)
- }
- }
- return nil
-}
-
// parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() {
diff --git a/parse_test.go b/parse_test.go
index 2e438aa..81cd2c3 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -969,3 +969,18 @@ func TestSpacesAllowedInTags(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo)
}
+
+func TestReuseParser(t *testing.T) {
+ var args struct {
+ Foo string `arg:"required"`
+ }
+
+ p, err := NewParser(Config{}, &args)
+ require.NoError(t, err)
+
+ err = p.Parse([]string{"--foo=abc"})
+ assert.Equal(t, args.Foo, "abc")
+
+ err = p.Parse([]string{})
+ assert.Error(t, err)
+}
diff --git a/usage.go b/usage.go
index 656ee9a..cfac563 100644
--- a/usage.go
+++ b/usage.go
@@ -1,12 +1,12 @@
package arg
import (
+ "encoding"
"fmt"
"io"
"os"
"reflect"
"strings"
- "encoding"
)
// the width of the left column
@@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) {
// WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) {
var positionals, options []*spec
- for _, spec := range p.spec {
+ for _, spec := range p.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {
@@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
// WriteHelp writes the usage string followed by the full help string for each option
func (p *Parser) WriteHelp(w io.Writer) {
var positionals, options []*spec
- for _, spec := range p.spec {
+ for _, spec := range p.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {