diff --git a/.golangci.yml b/.golangci.yml index 2b75b3a143c9..c24f2e179ecd 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -109,3 +109,9 @@ issues: - linters: - staticcheck text: "^SA1019:" + + # tcpproxy is copied from https://github.com/inetaf/tcpproxy/, as per + # Apache 2.0 license section 4 (Redistribution) we must keep the original header. + - path: "pkg/component/controller/cplb/tcpproxy/.*" + linters: + - goheader diff --git a/hack/copyright.sh b/hack/copyright.sh index 5a494c13e05b..dd4d701d35e2 100755 --- a/hack/copyright.sh +++ b/hack/copyright.sh @@ -45,6 +45,10 @@ has_date_copyright(){ # Copyright notice aren't related to the date of the document. for i in $(find cmd hack internal inttest pkg static -type f -name '*.go' -not -name 'zz_generated*'); do case "$i" in + pkg/component/controller/cplb/tcpproxy/*) + # These files have a special copyright due to being copied + # from github.com/inetaf/tcpproxy + ;; pkg/client/clientset/*) if ! has_basic_copyright "$i"; then echo "ERROR: $i doesn't have a proper copyright notice" 1>&2 diff --git a/pkg/component/controller/cplb/tcpproxy/tcpproxy.go b/pkg/component/controller/cplb/tcpproxy/tcpproxy.go new file mode 100644 index 000000000000..8691f5afeaba --- /dev/null +++ b/pkg/component/controller/cplb/tcpproxy/tcpproxy.go @@ -0,0 +1,492 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Modifications made by Mirantis Inc., 2024. +// Copyright 2017 Google Inc. 
+// +// Copyright 2024 Mirantis, Inc. + +// Package tcpproxy lets users build TCP proxies + +package tcpproxy + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "log" + "net" + "time" +) + +// Proxy is a proxy. Its zero value is a valid proxy that does +// nothing. Call methods to add routes before calling Start or Run. +// +// The order that routes are added in matters; each is matched in the order +// registered. +type Proxy struct { + configs map[string]*config // ip:port => config + + lns []net.Listener + donec chan struct{} // closed before err + err error // any error from listening + routesChan chan route + + // ListenFunc optionally specifies an alternate listen + // function. If nil, net.Listen is used. + // The provided net is always "tcp". + ListenFunc func(net, laddr string) (net.Listener, error) +} + +// Matcher reports whether hostname matches the Matcher's criteria. +type Matcher func(ctx context.Context, hostname string) bool + +// config contains the proxying state for one listener. +type config struct { + routes []route +} + +// A route matches a connection to a target. +type route interface { + // match examines the initial bytes of a connection, looking for a + // match. If a match is found, match returns a non-nil Target to + // which the stream should be proxied. match returns nil if the + // connection doesn't match. + // + // match must not consume bytes from the given bufio.Reader, it + // can only Peek. + // + // If an sni or host header was parsed successfully, that will be + // returned as the second parameter. 
+ match(*bufio.Reader) (Target, string) +} + +func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { + if p.ListenFunc != nil { + return p.ListenFunc + } + return net.Listen +} + +func (p *Proxy) configFor(ipPort string) *config { + if p.configs == nil { + p.configs = make(map[string]*config) + } + if p.configs[ipPort] == nil { + p.configs[ipPort] = &config{} + } + return p.configs[ipPort] +} + +func (p *Proxy) addRoute(ipPort string, r route) { + cfg := p.configFor(ipPort) + cfg.routes = append(cfg.routes, r) +} + +// AddRoute appends an always-matching route to the ipPort listener, +// directing any connection to dest. +// +// This is generally used as either the only rule (for simple TCP +// proxies), or as the final fallback rule for an ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddRoute(ipPort string, dest Target) { + p.addRoute(ipPort, fixedTarget{dest}) +} + +func (p *Proxy) setRoutes(ipPort string, targets []Target) { + var routes []route + for _, target := range targets { + routes = append(routes, fixedTarget{target}) + } + + cfg := p.configFor(ipPort) + cfg.routes = routes +} + +// SetRoutes replaces routes for the ipPort. +// +// It's possible that the old routes are still used once after this +// function is called. If an empty slice is passed, the routes are +// preserved in order to avoid an infinite loop. +func (p *Proxy) SetRoutes(ipPort string, targets []Target) { + if len(targets) == 0 { + return + } + p.setRoutes(ipPort, targets) +} + +type fixedTarget struct { + t Target +} + +func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" } + +// Run calls Start, and then Wait. +// +// It blocks until there's an error. The return value is always +// non-nil. +func (p *Proxy) Run() error { + if err := p.Start(); err != nil { + return err + } + return p.Wait() +} + +// Wait waits for the Proxy to finish running. 
Currently this can only +// happen if a Listener is closed, or Close is called on the proxy. +// +// It is only valid to call Wait after a successful call to Start. +func (p *Proxy) Wait() error { + <-p.donec + return p.err +} + +// Close closes all the proxy's self-opened listeners. +func (p *Proxy) Close() error { + for _, c := range p.lns { + c.Close() + } + return nil +} + +// Start creates a TCP listener for each unique ipPort from the +// previously created routes and starts the proxy. It returns any +// error from starting listeners. +// +// If it returns a non-nil error, any successfully opened listeners +// are closed. +func (p *Proxy) Start() error { + if p.donec != nil { + return errors.New("already started") + } + p.donec = make(chan struct{}) + errc := make(chan error, len(p.configs)) + p.lns = make([]net.Listener, 0, len(p.configs)) + p.routesChan = make(chan route) + for ipPort, config := range p.configs { + ln, err := p.netListen()("tcp", ipPort) + if err != nil { + p.Close() + return err + } + p.lns = append(p.lns, ln) + go p.serveListener(errc, ln, config) + } + go p.awaitFirstError(errc) + return nil +} + +func (p *Proxy) awaitFirstError(errc <-chan error) { + p.err = <-errc + close(p.donec) +} + +func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { + go p.roundRobin(cfg) + for { + c, err := ln.Accept() + if err != nil { + ret <- err + return + } + go p.serveConn(c) + } +} + +// serveConn runs in its own goroutine and matches c against routes. +// It returns whether it matched purely for testing. +func (p *Proxy) serveConn(c net.Conn) bool { + br := bufio.NewReader(c) + for route := range p.routesChan { + if target, hostName := route.match(br); target != nil { + if n := br.Buffered(); n > 0 { + peeked, _ := br.Peek(br.Buffered()) + c = &Conn{ + HostName: hostName, + Peeked: peeked, + Conn: c, + } + } + target.HandleConn(c) + return true + } + } + // TODO: hook for this? 
+ log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) + c.Close() + return false +} + +// roundRobin writes to a channel the next route to use. +func (p *Proxy) roundRobin(cfg *config) { + for { + for _, route := range cfg.routes { + p.routesChan <- route + } + } +} + +// Conn is an incoming connection that has had some bytes read from it +// to determine how to route the connection. The Read method stitches +// the peeked bytes and unread bytes back together. +type Conn struct { + // HostName is the hostname field that was sent to the request router. + // In the case of TLS, this is the SNI header, in the case of HTTPHost + // route, it will be the host header. In the case of a fixed + // route, i.e. those created with AddRoute(), this will always be + // empty. This can be useful in the case where further routing decisions + // need to be made in the Target implementation. + HostName string + + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It is set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. + net.Conn +} + +func (c *Conn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.Conn.Read(p) +} + +// Target is what an incoming matched connection is sent to. +type Target interface { + // HandleConn is called when an incoming connection is + // matched. After the call to HandleConn, the tcpproxy + // package never touches the conn again. Implementations are + // responsible for closing the connection when needed. 
+ // + // The concrete type of conn will be of type *Conn if any + // bytes have been consumed for the purposes of route + // matching. + HandleConn(net.Conn) +} + +// To is a shorthand way of writing &tcpproxy.DialProxy{Addr: addr}. +func To(addr string) *DialProxy { + return &DialProxy{Addr: addr} +} + +// DialProxy implements Target by dialing a new connection to Addr +// and then proxying data back and forth. +// +// The To func is a shorthand way of creating a DialProxy. +type DialProxy struct { + // Addr is the TCP address to proxy to. + Addr string + + // KeepAlivePeriod sets the period between TCP keep alives. + // If zero, a default is used. To disable, use a negative number. + // The keep-alive is used for both the client connection and the connection dialed to Addr. + KeepAlivePeriod time.Duration + + // DialTimeout optionally specifies a dial timeout. + // If zero, a default is used. + // If negative, the timeout is disabled. + DialTimeout time.Duration + + // DialContext optionally specifies an alternate dial function + // for TCP targets. If nil, the standard + // net.Dialer.DialContext method is used. + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + + // OnDialError optionally specifies an alternate way to handle errors dialing Addr. + // If nil, the error is logged and src is closed. + // If non-nil, src is not closed automatically. + OnDialError func(src net.Conn, dstDialErr error) + + // ProxyProtocolVersion optionally specifies the version of + // HAProxy's PROXY protocol to use. The PROXY protocol provides + // connection metadata to the DialProxy target, via a header + // inserted ahead of the client's traffic. The DialProxy target + // must explicitly support and expect the PROXY header; there is + // no graceful downgrade. + // If zero, no PROXY header is sent. Currently, version 1 is supported. + ProxyProtocolVersion int +} + +// UnderlyingConn returns c.Conn if c is of type *Conn, +// otherwise it returns c. 
+func UnderlyingConn(c net.Conn) net.Conn { + if wrap, ok := c.(*Conn); ok { + return wrap.Conn + } + return c +} + +func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) { + if c, ok := UnderlyingConn(c).(*net.TCPConn); ok { + return c, ok + } + if c, ok := c.(*net.TCPConn); ok { + return c, ok + } + return nil, false +} + +func goCloseConn(c net.Conn) { go c.Close() } + +func closeRead(c net.Conn) { + if c, ok := tcpConn(c); ok { + _ = c.CloseRead() + } +} + +func closeWrite(c net.Conn) { + if c, ok := tcpConn(c); ok { + _ = c.CloseWrite() + } +} + +// HandleConn implements the Target interface. +func (dp *DialProxy) HandleConn(src net.Conn) { + ctx := context.Background() + var cancel context.CancelFunc + if dp.DialTimeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout()) + } + dst, err := dp.dialContext()(ctx, "tcp", dp.Addr) + if cancel != nil { + cancel() + } + if err != nil { + dp.onDialError()(src, err) + return + } + defer goCloseConn(dst) + + if err = dp.sendProxyHeader(dst, src); err != nil { + dp.onDialError()(src, err) + return + } + defer goCloseConn(src) + + if ka := dp.keepAlivePeriod(); ka > 0 { + for _, c := range []net.Conn{src, dst} { + if c, ok := tcpConn(c); ok { + _ = c.SetKeepAlive(true) + _ = c.SetKeepAlivePeriod(ka) + } + } + } + + errc := make(chan error, 2) + go proxyCopy(errc, src, dst) + go proxyCopy(errc, dst, src) + <-errc + <-errc +} + +func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { + switch dp.ProxyProtocolVersion { + case 0: + return nil + case 1: + var srcAddr, dstAddr *net.TCPAddr + if a, ok := src.RemoteAddr().(*net.TCPAddr); ok { + srcAddr = a + } + if a, ok := src.LocalAddr().(*net.TCPAddr); ok { + dstAddr = a + } + + if srcAddr == nil || dstAddr == nil { + _, err := io.WriteString(w, "PROXY UNKNOWN\r\n") + return err + } + + family := "TCP4" + if srcAddr.IP.To4() == nil { + family = "TCP6" + } + _, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, 
srcAddr.Port, dstAddr.Port) + return err + default: + return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion) + } +} + +// proxyCopy is the function that copies bytes around. +// It's a named function instead of a func literal so users get +// named goroutines in debug goroutine stack dumps. +func proxyCopy(errc chan<- error, dst, src net.Conn) { + defer closeRead(src) + defer closeWrite(dst) + + // Before we unwrap src and/or dst, copy any buffered data. + if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 { + if _, err := dst.Write(wc.Peeked); err != nil { + errc <- err + return + } + wc.Peeked = nil + } + + // Unwrap the src and dst from *Conn to *net.TCPConn so Go + // 1.11's splice optimization kicks in. + src = UnderlyingConn(src) + dst = UnderlyingConn(dst) + + _, err := io.Copy(dst, src) + errc <- err +} + +func (dp *DialProxy) keepAlivePeriod() time.Duration { + if dp.KeepAlivePeriod != 0 { + return dp.KeepAlivePeriod + } + return time.Minute +} + +func (dp *DialProxy) dialTimeout() time.Duration { + if dp.DialTimeout > 0 { + return dp.DialTimeout + } + return 10 * time.Second +} + +var defaultDialer = new(net.Dialer) + +func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { + if dp.DialContext != nil { + return dp.DialContext + } + return defaultDialer.DialContext +} + +func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) { + if dp.OnDialError != nil { + return dp.OnDialError + } + return func(src net.Conn, dstDialErr error) { + log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr) + src.Close() + } +} diff --git a/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go b/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go new file mode 100644 index 000000000000..777a5c95e4ed --- /dev/null +++ b/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go @@ -0,0 +1,226 @@ +// Copyright 2017 Google 
Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Modifications made by Mirantis Inc., 2024. +// Copyright 2017 Google Inc. +// +// Copyright 2024 Mirantis, Inc. + +package tcpproxy + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "testing" +) + +func TestProxyStartNone(t *testing.T) { + var p Proxy + if err := p.Start(); err != nil { + t.Fatal(err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp", "[::1]:0") + if err != nil { + t.Fatal(err) + } + } + return ln +} + +const testFrontAddr = "1.2.3.4:567" + +func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) { + return func(network, laddr string) (net.Listener, error) { + if network != "tcp" { + t.Errorf("got Listen call with network %q, not tcp", network) + return nil, errors.New("invalid network") + } + if laddr != testFrontAddr { + t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr) + panic("bogus address") + } + return ln, nil + } +} + +func testProxy(t *testing.T, front net.Listener) *Proxy { + return &Proxy{ + ListenFunc: testListenFunc(t, front), + } +} + +func TestBufferedClose(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != 
nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + defer fromProxy.Close() + const msg = "message" + if _, err := io.WriteString(toFront, msg); err != nil { + t.Fatal(err) + } + // actively close toFront, the write should still make it to the back. + toFront.Close() + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxyAlwaysMatch(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + defer fromProxy.Close() + const msg = "message" + _, _ = io.WriteString(toFront, msg) + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxyPROXYOut(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, &DialProxy{ + Addr: back.Addr().String(), + ProxyProtocolVersion: 1, + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + + _, _ = io.WriteString(toFront, "foo") + toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + + bs, err := ioutil.ReadAll(fromProxy) + if err != nil { 
t.Fatal(err) + } + + want := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).Port) + if string(bs) != want { + t.Fatalf("got %q; want %q", bs, want) + } +} + +func TestSetRoutes(t *testing.T) { + + var p Proxy + ipPort := ":8080" + p.AddRoute(ipPort, To("127.0.0.2:8080")) + cfg := p.configFor(ipPort) + + expectedAddrsList := [][]string{ + {"127.0.0.1:80"}, + {"127.0.0.1:80", "127.0.0.1:443"}, + {}, + {"127.0.0.1:80"}, + } + + for _, expectedAddrs := range expectedAddrsList { + p.setRoutes(ipPort, stringsToTargets(expectedAddrs)) + if !equalRoutes(cfg.routes, expectedAddrs) { + t.Fatalf("got %v; want %v", cfg.routes, expectedAddrs) + } + } +} + +func stringsToTargets(s []string) []Target { + targets := make([]Target, len(s)) + for i, v := range s { + targets[i] = To(v) + } + + return targets +} +func equalRoutes(routes []route, expectedAddrs []string) bool { + if len(routes) != len(expectedAddrs) { + return false + } + + for i := range routes { + addr := routes[i].(fixedTarget).t.(*DialProxy).Addr + if addr != expectedAddrs[i] { + return false + } + } + return true +}