diff --git a/ReadMe.Md b/ReadMe.Md index 99049ea..30548e0 100644 --- a/ReadMe.Md +++ b/ReadMe.Md @@ -14,7 +14,7 @@ continuous integration scripts, never call sleep inside your pipeline again. [![Version](https://img.shields.io/badge/version-0.1.6-orange)](https://github.com/simonmittag/pwt/releases/tag/v0.1.6) ## What's New -### v0.1.6 +### v0.1.7 * -h -v flags for cli * bumped to go 1.20 diff --git a/cmd/pwt/main.go b/cmd/pwt/main.go index a604aa1..b909532 100644 --- a/cmd/pwt/main.go +++ b/cmd/pwt/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "github.com/simonmittag/pwt" @@ -9,7 +10,6 @@ import ( "strings" ) -const default_host = "localhost" const default_port = 80 type Mode uint8 @@ -22,21 +22,27 @@ const ( func main() { mode := Server + host := "" + port := default_port timeSeconds := flag.Int("w", 10, "time wait in seconds") v := flag.Bool("v", false, "print pwt version") h := flag.Bool("h", false, "print usage instructions") flag.Usage = printUsage - flag.Parse() - - host, port := parseArgs(flag.Args()) - - if *v { - mode = Version - } else if *h { + err := ParseFlags() + if err != nil { mode = Usage } else { - mode = Server + a := flag.Args() + host, port, err = parseArgs(a) + + if *v { + mode = Version + } else if err != nil || *h { + mode = Usage + } else { + mode = Server + } } switch mode { @@ -51,7 +57,7 @@ func main() { func printUsage() { printVersion() - fmt.Printf("Usage: pwt [-v]|[-w n] host[:port]\n") + fmt.Printf("Usage: pwt [-h]|[-v]|[-w n] host[:port]\n") flag.PrintDefaults() } @@ -66,8 +72,8 @@ func wait(host string, port int, timeSeconds int) { } } -func parseArgs(args []string) (string, int) { - var host string = default_host +func parseArgs(args []string) (string, int, error) { + var host = "" var port int = default_port if len(args) == 1 { @@ -75,13 +81,46 @@ func parseArgs(args []string) (string, int) { if (strings.Contains(dest, ":") && !strings.Contains(dest, "::")) || strings.Contains(dest, "]:") { ci := strings.LastIndex(args[0], ":") host = args[0][0:ci] - port, _ = strconv.Atoi(args[0][ci+1:]) + p, _ := strconv.Atoi(args[0][ci+1:]) + if p > 0 && p < 65535 { + port = p + } } else { - host = args[0] + host = dest + } + return host, port, nil + } else { + return "", 0, errors.New("invalid host or port") + } +} + +// ParseFlags parses the command line args, allowing flags to be +// specified after positional args. +func ParseFlags() error { + return ParseFlagSet(flag.CommandLine, os.Args[1:]) +} + +// ParseFlagSet works like flagset.Parse(), except positional arguments are not +// required to come after flag arguments. +func ParseFlagSet(flagset *flag.FlagSet, args []string) error { + var positionalArgs []string + for { + if err := flagset.Parse(args); err != nil { + return err + } + // Consume all the flags that were parsed as flags. + args = args[len(args)-flagset.NArg():] + if len(args) == 0 { + break } - } else if len(args) == 2 { - host = args[0] - port, _ = strconv.Atoi(args[1]) + // There's at least one flag remaining and it must be a positional arg since + // we consumed all args that were parsed as flags. Consume just the first + // one, and retry parsing, since subsequent args may be flags. + positionalArgs = append(positionalArgs, args[0]) + args = args[1:] } - return host, port + // Parse just the positional args so that flagset.Args()/flagset.NArgs() + // return the expected value. + // Note: This should never return an error. + return flagset.Parse(positionalArgs) } diff --git a/cmd/pwt/main_test.go b/cmd/pwt/main_test.go index 042b97d..27b0870 100644 --- a/cmd/pwt/main_test.go +++ b/cmd/pwt/main_test.go @@ -3,6 +3,7 @@ package main import ( "os" "testing" + "time" ) //TODO: cannot run these tests because multiple invocations of flag.Arg() @@ -16,78 +17,103 @@ import ( // main() //} -func TestMainFuncWithArgs(t *testing.T) { - os.Args = append([]string{""}, "www.google.com", "443") +func TestMainFuncWithPositionalArgsBeforeFlagArgs(t *testing.T) { + //first arg is command + before := time.Now() + os.Args = []string{"pwt", "www.google.com:888", "-w", "4"} main() + + after := time.Now() + got := after.Sub(before).Seconds() + want := time.Duration(time.Second * 4).Seconds() + if float64(got) < float64(want) { + t.Errorf("wanted at least %v seconds, but got %v", want, got) + } + } func TestParseArgsHostNameColonPort(t *testing.T) { - host, port := parseArgs([]string{"somehost:8083"}) + host, port, err := parseArgs([]string{"somehost:8083"}) if host != "somehost" { t.Errorf("did not resolve host want somehost, got %s", host) } if port != 8083 { t.Errorf("did not resolve port want 8083 got %d", port) } + if err != nil { + t.Errorf("should not have errored, but got %v", err) + } } func TestParseArgsHostName(t *testing.T) { - host, port := parseArgs([]string{"myhost.com"}) + host, port, err := parseArgs([]string{"myhost.com"}) if host != "myhost.com" { t.Errorf("did not resolve host want myhost.com got %s", host) } if port != default_port { t.Errorf("did not resolve port want %d, got %d", default_port, port) } + if err != nil { + t.Errorf("should not have errored, but got %v", err) + } } func TestParseArgsIpv4(t *testing.T) { - host, port := parseArgs([]string{"127.0.0.1"}) + host, port, err := parseArgs([]string{"127.0.0.1"}) if host != "127.0.0.1" { t.Errorf("did not resolve host want 127.0.0.1 got %s", host) } if port != default_port { t.Errorf("did not resolve port want %d, got %d", default_port, port) } + if err != nil { + t.Errorf("should not have errored, but got %v", err) + } } func TestParseArgsIpv4ColonPort(t *testing.T) { - host, port := parseArgs([]string{"127.0.0.1:8081"}) + host, port, err := parseArgs([]string{"127.0.0.1:8081"}) if host != "127.0.0.1" { t.Errorf("did not resolve host want 127.0.0.1 got %s", host) } if port != 8081 { t.Errorf("did not resolve port want %d, got %d", 8081, port) } + if err != nil { + t.Errorf("should not have errored, but got %v", err) + } } func TestParseArgsIpv6(t *testing.T) { - host, port := parseArgs([]string{"[::1]"}) + host, port, err := parseArgs([]string{"[::1]"}) if host != "[::1]" { t.Errorf("did not resolve host want [::1] got %s", host) } if port != default_port { t.Errorf("did not resolve port want %d, got %d", default_port, port) } + if err != nil { + t.Errorf("should not have errored, but got %v", err) + } } func TestParseArgsIpv6ColonPort(t *testing.T) { - host, port := parseArgs([]string{"[::1]:8087"}) + host, port, err := parseArgs([]string{"[::1]:8087"}) if host != "[::1]" { t.Errorf("did not resolve host want [::1] got %s", host) } if port != 8087 { t.Errorf("did not resolve port want %d, got %d", 8087, port) } + if err != nil { + t.Errorf("should not have errored, but got %v", err) + } } func TestParseArgsHostAndPort(t *testing.T) { - host, port := parseArgs([]string{"hostname.com", "8089"}) - if host != "hostname.com" { - t.Errorf("did not resolve host want [::1] got %s", host) - } - if port != 8089 { - t.Errorf("did not resolve port want %d, got %d", 8087, port) + _, _, err := parseArgs([]string{"hostname.com", "8089"}) + if err == nil { + t.Errorf("should have failed %v", err) } } diff --git a/pwt.go b/pwt.go index 9fe0468..85e92ed 100644 --- a/pwt.go +++ b/pwt.go @@ -6,7 +6,7 @@ import ( "time" ) -const Version string = "v0.1.6" +const Version string = "v0.1.7" var dialler = &net.Dialer{ Timeout: 1 * time.Second,