diff --git a/pkg/component/controller/cplb/cplb_linux.go b/pkg/component/controller/cplb/cplb_linux.go index 04d7fcb3378f..72d0afd06f71 100644 --- a/pkg/component/controller/cplb/cplb_linux.go +++ b/pkg/component/controller/cplb/cplb_linux.go @@ -349,7 +349,7 @@ func (k *Keepalived) watchReconcilerUpdatesReverseProxy() error { k.proxy = tcpproxy.Proxy{} // We don't know how long until we get the first update, so initially we // forward everything to localhost - k.proxy.AddRoute(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), tcpproxy.To(fmt.Sprintf("127.0.0.1:%d", k.APIPort))) + k.proxy.SetRoutes(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), []tcpproxy.Route{tcpproxy.To(fmt.Sprintf("127.0.0.1:%d", k.APIPort))}) if err := k.proxy.Start(); err != nil { return fmt.Errorf("failed to start proxy: %w", err) @@ -372,11 +372,15 @@ func (k *Keepalived) watchReconcilerUpdatesReverseProxy() error { } func (k *Keepalived) setProxyRoutes() { - routes := []tcpproxy.Target{} + routes := []tcpproxy.Route{} for _, addr := range k.reconciler.GetIPs() { routes = append(routes, tcpproxy.To(fmt.Sprintf("%s:%d", addr, k.APIPort))) } + if len(routes) == 0 { + k.log.Error("No API servers available, leave previous configuration") + return + } k.proxy.SetRoutes(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), routes) } diff --git a/pkg/component/controller/cplb/tcpproxy/tcpproxy.go b/pkg/component/controller/cplb/tcpproxy/tcpproxy.go index 23b4b7fc52bd..a6fa16d11103 100644 --- a/pkg/component/controller/cplb/tcpproxy/tcpproxy.go +++ b/pkg/component/controller/cplb/tcpproxy/tcpproxy.go @@ -27,9 +27,11 @@ import ( "errors" "fmt" "io" - "log" "net" + "sync" "time" + + "github.com/sirupsen/logrus" ) // Proxy is a proxy. Its zero value is a valid proxy that does @@ -38,12 +40,13 @@ import ( // The order that routes are added in matters; each is matched in the order // registered. type Proxy struct { + mux sync.RWMutex configs map[string]*config // ip:port => config lns []net.Listener donec chan struct{} // closed before err err error // any error from listening - routesChan chan route + connNumber int // connection number counter, used for round robin // ListenFunc optionally specifies an alternate listen // function. If nil, net.Dial is used. @@ -56,22 +59,7 @@ type Matcher func(ctx context.Context, hostname string) bool // config contains the proxying state for one listener. type config struct { - routes []route -} - -// A route matches a connection to a target. -type route interface { - // match examines the initial bytes of a connection, looking for a - // match. If a match is found, match returns a non-nil Target to - // which the stream should be proxied. match returns nil if the - // connection doesn't match. - // - // match must not consume bytes from the given bufio.Reader, it - // can only Peek. - // - // If an sni or host header was parsed successfully, that will be - // returned as the second parameter. - match(*bufio.Reader) (Target, string) + routes []Route } func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { @@ -91,28 +79,7 @@ func (p *Proxy) configFor(ipPort string) *config { return p.configs[ipPort] } -func (p *Proxy) addRoute(ipPort string, r route) { - cfg := p.configFor(ipPort) - cfg.routes = append(cfg.routes, r) -} - -// AddRoute appends an always-matching route to the ipPort listener, -// directing any connection to dest. -// -// This is generally used as either the only rule (for simple TCP -// proxies), or as the final fallback rule for an ipPort. -// -// The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddRoute(ipPort string, dest Target) { - p.addRoute(ipPort, fixedTarget{dest}) -} - -func (p *Proxy) setRoutes(ipPort string, targets []Target) { - var routes []route - for _, target := range targets { - routes = append(routes, fixedTarget{target}) - } - +func (p *Proxy) setRoutes(ipPort string, routes []Route) { cfg := p.configFor(ipPort) cfg.routes = routes } @@ -122,19 +89,15 @@ func (p *Proxy) setRoutes(ipPort string, targets []Target) { // It's possible that the old routes are still used once after this // function is called. If an empty slice is passed, the routes are // preserved in order to avoid an infinite loop. -func (p *Proxy) SetRoutes(ipPort string, targets []Target) { +func (p *Proxy) SetRoutes(ipPort string, targets []Route) { + p.mux.Lock() + defer p.mux.Unlock() if len(targets) == 0 { - return + panic("SetRoutes with empty targets") } p.setRoutes(ipPort, targets) } -type fixedTarget struct { - t Target -} - -func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" } - // Run is calls Start, and then Wait. // // It blocks until there's an error. The return value is always @@ -183,7 +146,6 @@ func (p *Proxy) Start() error { return err } p.lns = append(p.lns, ln) - p.routesChan = make(chan route) go p.serveListener(errc, ln, config) } go p.awaitFirstError(errc) @@ -196,48 +158,35 @@ func (p *Proxy) awaitFirstError(errc <-chan error) { } func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { - go p.roundRobin(cfg) for { c, err := ln.Accept() if err != nil { ret <- err return } - go p.serveConn(c) + go p.serveConn(c, cfg) } } // serveConn runs in its own goroutine and matches c against routes. // It returns whether it matched purely for testing. -func (p *Proxy) serveConn(c net.Conn) bool { +func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - for route := range p.routesChan { - if target, hostName := route.match(br); target != nil { - if n := br.Buffered(); n > 0 { - peeked, _ := br.Peek(br.Buffered()) - c = &Conn{ - HostName: hostName, - Peeked: peeked, - Conn: c, - } - } - target.HandleConn(c) - return true - } - } - // TODO: hook for this? - log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) - c.Close() - return false -} -// roundRobin writes to a channel the next route to use. -func (p *Proxy) roundRobin(cfg *config) { - for { - for _, route := range cfg.routes { - p.routesChan <- route + p.mux.RLock() + p.connNumber++ + route := cfg.routes[p.connNumber%(len(cfg.routes))] + p.mux.RUnlock() + + if n := br.Buffered(); n > 0 { + peeked, _ := br.Peek(br.Buffered()) + c = &Conn{ + Peeked: peeked, + Conn: c, } } + route.HandleConn(c) + return true } // Conn is an incoming connection that has had some bytes read from it @@ -276,29 +225,17 @@ func (c *Conn) Read(p []byte) (n int, err error) { return c.Conn.Read(p) } -// Target is what an incoming matched connection is sent to. -type Target interface { - // HandleConn is called when an incoming connection is - // matched. After the call to HandleConn, the tcpproxy - // package never touches the conn again. Implementations are - // responsible for closing the connection when needed. - // - // The concrete type of conn will be of type *Conn if any - // bytes have been consumed for the purposes of route - // matching. - HandleConn(net.Conn) -} - // To is shorthand way of writing &tcpproxy.DialProxy{Addr: addr}. -func To(addr string) *DialProxy { - return &DialProxy{Addr: addr} +func To(addr string) Route { + return Route{Addr: addr} } -// DialProxy implements Target by dialing a new connection to Addr +// Route is what an incoming connection is sent to. +// It handles them by dialing a new connection to Addr // and then proxying data back and forth. // -// The To func is a shorthand way of creating a DialProxy. -type DialProxy struct { +// The To func is a shorthand way of creating a Route. +type Route struct { // Addr is the TCP address to proxy to. Addr string @@ -366,29 +303,29 @@ func closeWrite(c net.Conn) { } // HandleConn implements the Target interface. -func (dp *DialProxy) HandleConn(src net.Conn) { +func (r *Route) HandleConn(src net.Conn) { ctx := context.Background() var cancel context.CancelFunc - if dp.DialTimeout >= 0 { - ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout()) + if r.DialTimeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, r.dialTimeout()) } - dst, err := dp.dialContext()(ctx, "tcp", dp.Addr) + dst, err := r.dialContext()(ctx, "tcp", r.Addr) if cancel != nil { cancel() } if err != nil { - dp.onDialError()(src, err) + r.onDialError()(src, err) return } defer goCloseConn(dst) - if err = dp.sendProxyHeader(dst, src); err != nil { - dp.onDialError()(src, err) + if err = r.sendProxyHeader(dst, src); err != nil { + r.onDialError()(src, err) return } defer goCloseConn(src) - if ka := dp.keepAlivePeriod(); ka > 0 { + if ka := r.keepAlivePeriod(); ka > 0 { for _, c := range []net.Conn{src, dst} { if c, ok := tcpConn(c); ok { _ = c.SetKeepAlive(true) @@ -404,8 +341,8 @@ func (dp *DialProxy) HandleConn(src net.Conn) { <-errc } -func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { - switch dp.ProxyProtocolVersion { +func (r *Route) sendProxyHeader(w io.Writer, src net.Conn) error { + switch r.ProxyProtocolVersion { case 0: return nil case 1: @@ -429,7 +366,7 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { _, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port) return err default: - return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion) + return fmt.Errorf("PROXY protocol version %d not supported", r.ProxyProtocolVersion) } } @@ -458,35 +395,35 @@ func proxyCopy(errc chan<- error, dst, src net.Conn) { errc <- err } -func (dp *DialProxy) keepAlivePeriod() time.Duration { - if dp.KeepAlivePeriod != 0 { - return dp.KeepAlivePeriod +func (r *Route) keepAlivePeriod() time.Duration { + if r.KeepAlivePeriod != 0 { + return r.KeepAlivePeriod } return time.Minute } -func (dp *DialProxy) dialTimeout() time.Duration { - if dp.DialTimeout > 0 { - return dp.DialTimeout +func (r *Route) dialTimeout() time.Duration { + if r.DialTimeout > 0 { + return r.DialTimeout } return 10 * time.Second } var defaultDialer = new(net.Dialer) -func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { - if dp.DialContext != nil { - return dp.DialContext +func (r *Route) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { + if r.DialContext != nil { + return r.DialContext } return defaultDialer.DialContext } -func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) { - if dp.OnDialError != nil { - return dp.OnDialError +func (r *Route) onDialError() func(src net.Conn, dstDialErr error) { + if r.OnDialError != nil { + return r.OnDialError } return func(src net.Conn, dstDialErr error) { - log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr) + logrus.WithFields(logrus.Fields{"component": "tcpproxy"}).Errorf("for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), r.Addr, dstDialErr) src.Close() } } diff --git a/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go b/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go index 777a5c95e4ed..ff9f4be6b557 100644 --- a/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go +++ b/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go @@ -75,7 +75,7 @@ func TestBufferedClose(t *testing.T) { defer back.Close() p := testProxy(t, front) - p.AddRoute(testFrontAddr, To(back.Addr().String())) + p.SetRoutes(testFrontAddr, []Route{To(back.Addr().String())}) if err := p.Start(); err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func TestProxyAlwaysMatch(t *testing.T) { defer back.Close() p := testProxy(t, front) - p.AddRoute(testFrontAddr, To(back.Addr().String())) + p.setRoutes(testFrontAddr, []Route{To(back.Addr().String())}) if err := p.Start(); err != nil { t.Fatal(err) } @@ -149,10 +149,10 @@ func TestProxyPROXYOut(t *testing.T) { defer back.Close() p := testProxy(t, front) - p.AddRoute(testFrontAddr, &DialProxy{ + p.SetRoutes(testFrontAddr, []Route{{ Addr: back.Addr().String(), ProxyProtocolVersion: 1, - }) + }}) if err := p.Start(); err != nil { t.Fatal(err) } @@ -185,7 +185,7 @@ func TestSetRoutes(t *testing.T) { var p Proxy ipPort := ":8080" - p.AddRoute(ipPort, To("127.0.0.2:8080")) + p.setRoutes(ipPort, []Route{To("127.0.0.2:8080")}) cfg := p.configFor(ipPort) expectedAddrsList := [][]string{ @@ -203,21 +203,21 @@ func TestSetRoutes(t *testing.T) { } } -func stringsToTargets(s []string) []Target { - targets := make([]Target, len(s)) +func stringsToTargets(s []string) []Route { + targets := make([]Route, len(s)) for i, v := range s { targets[i] = To(v) } return targets } -func equalRoutes(routes []route, expectedAddrs []string) bool { +func equalRoutes(routes []Route, expectedAddrs []string) bool { if len(routes) != len(expectedAddrs) { return false } for i := range routes { - addr := routes[i].(fixedTarget).t.(*DialProxy).Addr + addr := routes[i].Addr if addr != expectedAddrs[i] { return false }