diff --git a/main.go b/main.go index 6b63dd7..2fb0668 100644 --- a/main.go +++ b/main.go @@ -17,12 +17,14 @@ import ( var version = "tip" type options struct { - Src string - Dst string - Package string - Namespace string - Insecure bool - Version bool + Src string + Dst string + Package string + Namespace string + Insecure bool + ClientCertFile string + ClientKeyFile string + Version bool } func main() { @@ -33,6 +35,8 @@ func main() { flag.StringVar(&opts.Namespace, "n", opts.Namespace, "override namespace") flag.StringVar(&opts.Package, "p", opts.Package, "package name") flag.BoolVar(&opts.Insecure, "yolo", opts.Insecure, "accept invalid https certificates") + flag.StringVar(&opts.ClientCertFile, "cert", opts.ClientCertFile, "use client TLS cert file") + flag.StringVar(&opts.ClientKeyFile, "key", opts.ClientKeyFile, "use client TLS key file") flag.BoolVar(&opts.Version, "version", opts.Version, "show version and exit") flag.Parse() if opts.Version { @@ -52,7 +56,7 @@ func main() { w = f } - cli := httpClient(opts.Insecure) + cli := httpClient(opts.Insecure, opts.ClientCertFile, opts.ClientKeyFile) err := codegen(w, opts, cli) if err != nil { @@ -99,7 +103,22 @@ func open(name string, cli *http.Client) (io.ReadCloser, error) { } // httpClient returns http client with default options -func httpClient(insecure bool) *http.Client { +func httpClient(insecure bool, clientCertPath, clientKeyPath string) *http.Client { + tlsConfig := &tls.Config{InsecureSkipVerify: insecure} + + if clientCertPath != "" && clientKeyPath != "" { + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if err != nil { + log.Fatalln("Failed to load x509 client key pair:", err) + } + tlsConfig.Certificates = []tls.Certificate{clientCert} + tlsConfig.Renegotiation = tls.RenegotiateFreelyAsClient + } else if clientCertPath == "" && clientKeyPath != "" { + log.Fatalln("Certificate file is required when using key file") + } else if clientCertPath != "" && clientKeyPath == "" { + log.Fatalln("Key file is required when using certificate file") + } + defaultTransport := http.DefaultTransport.(*http.Transport) transport := &http.Transport{ Proxy: defaultTransport.Proxy, @@ -108,7 +127,7 @@ func httpClient(insecure bool) *http.Client { IdleConnTimeout: defaultTransport.IdleConnTimeout, ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, TLSHandshakeTimeout: defaultTransport.TLSHandshakeTimeout, - TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}, + TLSClientConfig: tlsConfig, } return &http.Client{Transport: transport} }