diff --git a/go.mod b/go.mod
index 01a079d844e..830721ccf25 100644
--- a/go.mod
+++ b/go.mod
@@ -174,4 +174,4 @@ require (
 	sigs.k8s.io/structured-merge-diff/v4 v4.3.0 // indirect
 )
 
-replace github.com/gocql/gocql => github.com/scylladb/gocql v1.7.3
+replace github.com/gocql/gocql => github.com/scylladb/gocql v1.12.0
diff --git a/go.sum b/go.sum
index 2ebc51cb9a8..16d623be136 100644
--- a/go.sum
+++ b/go.sum
@@ -355,8 +355,8 @@ github.com/scylladb/go-reflectx v1.0.1 h1:b917wZM7189pZdlND9PbIJ6NQxfDPfBvUaQ7cj
 github.com/scylladb/go-reflectx v1.0.1/go.mod h1:rWnOfDIRWBGN0miMLIcoPt/Dhi2doCMZqwMCJ3KupFc=
 github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE=
 github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs=
-github.com/scylladb/gocql v1.7.3 h1:tCZ44eA4SDC69SHgp1XUcEdWcXi5CQb+iaMOrpncwvI=
-github.com/scylladb/gocql v1.7.3/go.mod h1:TA7opQwU+6t8LmGZr/oyudP4QhVj3ucqbtZ73Xu4ghY=
+github.com/scylladb/gocql v1.12.0 h1:KaP25dC2Mu0H382M8KZmkQp1XuasgBG97bBhFeFKVyk=
+github.com/scylladb/gocql v1.12.0/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0=
 github.com/scylladb/gocqlx/v2 v2.8.0 h1:f/oIgoEPjKDKd+RIoeHqexsIQVIbalVmT+axwvUqQUg=
 github.com/scylladb/gocqlx/v2 v2.8.0/go.mod h1:4/+cga34PVqjhgSoo5Nr2fX1MQIqZB5eCE5DK4xeDig=
 github.com/scylladb/scylladb-swagger-go-client v0.2.0 h1:WRzrS07NSQSwxkKz67UnzssOEtRb2Ri4yUEi/xOMLQ0=
diff --git a/vendor/github.com/gocql/gocql/AUTHORS b/vendor/github.com/gocql/gocql/AUTHORS
index b13434176a2..a6eeeabc434 100644
--- a/vendor/github.com/gocql/gocql/AUTHORS
+++ b/vendor/github.com/gocql/gocql/AUTHORS
@@ -126,3 +126,19 @@ Maxim Vladimirskiy
 Bogdan-Ciprian Rusu
 Yuto Doi
 Krishna Vadali
+Jens-W. Schicke-Uffmann
+Ondrej Polakovič
+Sergei Karetnikov
+Stefan Miklosovic
+Adam Burk
+Valerii Ponomarov
+Neal Turett
+Doug Schaapveld
+Steven Seidman
+Wojciech Przytuła
+João Reis
+Lauro Ramos Venancio
+Dmitry Kropachev
+Oliver Boyle
+Jackson Fleming
+Sylwia Szunejko
diff --git a/vendor/github.com/gocql/gocql/CONTRIBUTING.md b/vendor/github.com/gocql/gocql/CONTRIBUTING.md
index 093045a31d7..8c2df74b7f0 100644
--- a/vendor/github.com/gocql/gocql/CONTRIBUTING.md
+++ b/vendor/github.com/gocql/gocql/CONTRIBUTING.md
@@ -20,7 +20,7 @@ The following is a check list of requirements that need to be satisfied in order
 * The test coverage does not fall below the critical threshold (currently 64%)
 * The merge commit passes the regression test suite on Travis
 * `go fmt` has been applied to the submitted code
-* Functional changes (i.e. new features or changed behavior) are appropriately documented, either as a godoc or in the README (non-functional changes such as bug fixes may not require documentation)
+* Notable changes (e.g. new features, changed behavior, bug fixes) are appropriately documented in CHANGELOG.md; functional changes are also documented in godoc
 
 If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements.
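The effect of the `replace` bump above can be confirmed from any binary that imports the driver. A minimal sketch (a hypothetical `main` package, not part of this diff) using the standard `runtime/debug` API:

```go
package main

import (
	"fmt"
	"runtime/debug"
)

func main() {
	// Build info is only populated in binaries built with module support.
	info, ok := debug.ReadBuildInfo()
	if !ok {
		fmt.Println("no module build info available")
		return
	}
	for _, dep := range info.Deps {
		// The import path stays github.com/gocql/gocql; the replace
		// directive is what points it at the Scylla fork.
		if dep.Path == "github.com/gocql/gocql" && dep.Replace != nil {
			// With the replace directive above this prints:
			// github.com/scylladb/gocql v1.12.0
			fmt.Println(dep.Replace.Path, dep.Replace.Version)
		}
	}
}
```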
diff --git a/vendor/github.com/gocql/gocql/README.md b/vendor/github.com/gocql/gocql/README.md
index f4f5dd586b2..ae62c4f1e8e 100644
--- a/vendor/github.com/gocql/gocql/README.md
+++ b/vendor/github.com/gocql/gocql/README.md
@@ -15,6 +15,11 @@ There are open pull requests to merge the functionality to the upstream project:
 It also provides support for shard aware ports, a faster way to connect to all shards, details available in [blogpost](https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/).
 
+Sunsetting Model
+----------------
+
+In general, the gocql team will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for those versions has been sunset.
+
 Installation
 ------------
 
diff --git a/vendor/github.com/gocql/gocql/cluster.go b/vendor/github.com/gocql/gocql/cluster.go
index 140c5a5eaff..66206ea9b34 100644
--- a/vendor/github.com/gocql/gocql/cluster.go
+++ b/vendor/github.com/gocql/gocql/cluster.go
@@ -17,6 +17,8 @@ import (
 type PoolConfig struct {
 	// HostSelectionPolicy sets the policy for selecting which host to use for a
 	// given query (default: RoundRobinHostPolicy())
+	// It is not supported to use a single HostSelectionPolicy in multiple sessions
+	// (even if you close the old session before using it in a new session).
 	HostSelectionPolicy HostSelectionPolicy
 }
 
@@ -49,13 +51,30 @@ type ClusterConfig struct {
 	// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
 	ProtoVersion int
 
-	// Connection timeout (default: 600ms)
+	// Timeout limits the time spent on the client side while executing a query.
+	// Specifically, query or batch execution will return an error if the client does not receive a response
+	// from the server within the Timeout period.
+	// Timeout is also used to configure the read timeout on the underlying network connection.
+	// Client Timeout should always be higher than the request timeouts configured on the server,
+	// so that retries don't overload the server.
+	// Timeout has a default value of 11 seconds, which is higher than the default server timeout for most query types.
+	// Timeout is not applied to requests during initial connection setup, see ConnectTimeout.
 	Timeout time.Duration
 
-	// Initial connection timeout, used during initial dial to server (default: 600ms)
-	// ConnectTimeout is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
+	// ConnectTimeout limits the time spent during connection setup.
+	// During initial connection setup, internal queries and AUTH requests will return an error if the client
+	// does not receive a response within the ConnectTimeout period.
+	// ConnectTimeout is applied to the connection setup queries independently.
+	// ConnectTimeout also limits the duration of dialing a new TCP connection
+	// in case neither Dialer nor HostDialer is configured.
+	// ConnectTimeout has a default value of 11 seconds.
 	ConnectTimeout time.Duration
 
+	// WriteTimeout limits the time the driver waits to write a request to a network connection.
+	// WriteTimeout should be lower than or equal to Timeout.
+	// WriteTimeout defaults to the value of Timeout.
+	WriteTimeout time.Duration
+
 	// Port used when dialing.
 	// Default: 9042
 	Port int
@@ -67,6 +86,11 @@ type ClusterConfig struct {
 	// Default: 2
 	NumConns int
 
+	// Maximum number of inflight requests allowed per connection.
+	// Default: 32768 for CQL v3 and newer
+	// Default: 128 for older CQL versions
+	MaxRequestsPerConn int
+
 	// Default consistency level.
 	// Default: Quorum
 	Consistency Consistency
@@ -191,6 +215,10 @@ type ClusterConfig struct {
 	// Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver.
 	FrameHeaderObserver FrameHeaderObserver
 
+	// StreamObserver will be notified of stream state changes.
+	// This can be used to track in-flight protocol requests and responses.
+	StreamObserver StreamObserver
+
 	// Default idempotence for queries
 	DefaultIdempotence bool
 
@@ -249,8 +277,8 @@ func NewCluster(hosts ...string) *ClusterConfig {
 	cfg := &ClusterConfig{
 		Hosts:          hosts,
 		CQLVersion:     "3.0.0",
-		Timeout:        600 * time.Millisecond,
-		ConnectTimeout: 600 * time.Millisecond,
+		Timeout:        11 * time.Second,
+		ConnectTimeout: 11 * time.Second,
 		Port:           9042,
 		NumConns:       2,
 		Consistency:    Quorum,
diff --git a/vendor/github.com/gocql/gocql/conn.go b/vendor/github.com/gocql/gocql/conn.go
index abb8b80f092..9a223f80a32 100644
--- a/vendor/github.com/gocql/gocql/conn.go
+++ b/vendor/github.com/gocql/gocql/conn.go
@@ -34,6 +34,7 @@ var (
 		"com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator",
 		"com.scylladb.auth.SaslauthdAuthenticator",
 		"com.scylladb.auth.TransitionalAuthenticator",
+		"com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator",
 	}
 )
 
@@ -93,13 +94,13 @@ func (p PasswordAuthenticator) Success(data []byte) error {
 // to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
 // SslOptions and Config.InsecureSkipVerify interact as follows:
 //
-// Config.InsecureSkipVerify | EnableHostVerification | Result
-// Config is nil             | false                  | do not verify host
-// Config is nil             | true                   | verify host
-// false                     | false                  | verify host
-// true                      | false                  | do not verify host
-// false                     | true                   | verify host
-// true                      | true                   | verify host
+//	Config.InsecureSkipVerify | EnableHostVerification | Result
+//	Config is nil             | false                  | do not verify host
+//	Config is nil             | true                   | verify host
+//	false                     | false                  | verify host
+//	true                      | false                  | do not verify host
+//	false                     | true                   | verify host
+//	true                      | true                   | verify host
 type SslOptions struct {
 	*tls.Config
 
@@ -122,6 +123,7 @@ type ConnConfig struct {
 	ProtoVersion   int
 	CQLVersion     string
 	Timeout        time.Duration
+	WriteTimeout   time.Duration
 	ConnectTimeout time.Duration
 	Dialer         Dialer
 	HostDialer     HostDialer
@@ -166,17 +168,22 @@ var TimeoutLimit int64 = 0
 type Conn struct {
 	conn net.Conn
 	r    *bufio.Reader
-	w    io.Writer
+	w    contextWriter
 
-	timeout       time.Duration
-	cfg           *ConnConfig
-	frameObserver FrameHeaderObserver
+	timeout        time.Duration
+	writeTimeout   time.Duration
+	cfg            *ConnConfig
+	frameObserver  FrameHeaderObserver
+	streamObserver StreamObserver
 
 	headerBuf [maxFrameHeaderSize]byte
 
 	streams *streams.IDGenerator
 	mu      sync.Mutex
-	calls   map[int]*callReq
+	// calls stores a map from stream ID to callReq.
+	// This map is protected by mu.
+	// calls should not be used when closed is true; calls is set to nil when closed=true.
+	calls map[int]*callReq
 
 	errorHandler ConnErrorHandler
 	compressor   Compressor
@@ -189,10 +196,13 @@ type Conn struct {
 	supported       map[string][]string
 	scyllaSupported scyllaSupported
 	cqlProtoExts    []cqlProtocolExtension
+	isSchemaV2      bool
 
 	session *Session
 
-	closed int32
+	// true if the close process for the connection has started.
+	// closed is protected by mu.
+	closed bool
 
 	ctx    context.Context
 	cancel context.CancelFunc
@@ -262,6 +272,11 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
 		return nil, err
 	}
 
+	writeTimeout := cfg.Timeout
+	if cfg.WriteTimeout > 0 {
+		writeTimeout = cfg.WriteTimeout
+	}
+
 	ctx, cancel := context.WithCancel(ctx)
 	c := &Conn{
 		conn:         dialedHost.Conn,
@@ -273,16 +288,21 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
 		errorHandler: errorHandler,
 		compressor:   cfg.Compressor,
 		session:      s,
-		streams:      streams.New(cfg.ProtoVersion),
+		streams:      s.streamIDGenerator(cfg.ProtoVersion),
 		host:         host,
+		isSchemaV2:   true, // Try using "system.peers_v2" until proven otherwise
 		frameObserver: s.frameObserver,
-		w: &deadlineWriter{
-			w:       dialedHost.Conn,
-			timeout: cfg.Timeout,
+		w: &deadlineContextWriter{
+			w:         dialedHost.Conn,
+			timeout:   writeTimeout,
+			semaphore: make(chan struct{}, 1),
+			quit:      make(chan struct{}),
 		},
-		ctx:    ctx,
-		cancel: cancel,
-		logger: cfg.logger(),
+		ctx:            ctx,
+		cancel:         cancel,
+		logger:         cfg.logger(),
+		streamObserver: s.streamObserver,
+		writeTimeout:   writeTimeout,
 	}
 
 	if err := c.init(ctx, dialedHost); err != nil {
@@ -294,6 +314,13 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
 
 	return c, nil
 }
 
+func (s *Session) streamIDGenerator(protocol int) *streams.IDGenerator {
+	if s.cfg.MaxRequestsPerConn > 0 {
+		return streams.NewLimited(s.cfg.MaxRequestsPerConn)
+	}
+	return streams.New(protocol)
+}
+
 func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
 	if c.session.cfg.AuthProvider != nil {
 		var err error
@@ -319,7 +346,7 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
 
 	// don't coalesce startup frames
 	if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce {
-		c.w = newWriteCoalescer(c.conn, c.timeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
+		c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
 	}
 
 	go c.serve(ctx)
@@ -329,7 +356,7 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
 }
 
 func (c *Conn) Write(p []byte) (n int, err error) {
-	return c.w.Write(p)
+	return c.w.writeContext(context.Background(), p)
 }
 
 func (c *Conn) Read(p []byte) (n int, err error) {
@@ -405,7 +432,7 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error {
 	return nil
 }
 
-func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
+func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (frame, error) {
 	select {
 	case s.frameTicker <- struct{}{}:
 	case <-ctx.Done():
@@ -441,7 +468,9 @@ func (s *startupCoordinator) options(ctx context.Context) error {
 func (s *startupCoordinator) startup(ctx context.Context) error {
 	m := map[string]string{
-		"CQL_VERSION": s.conn.cfg.CQLVersion,
+		"CQL_VERSION":    s.conn.cfg.CQLVersion,
+		"DRIVER_NAME":    driverName,
+		"DRIVER_VERSION": driverVersion,
 	}
 
 	if s.conn.compressor != nil {
@@ -528,23 +557,38 @@ func (c *Conn) closeWithError(err error) {
 		return
 	}
 
-	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
+	c.mu.Lock()
+	if c.closed {
+		c.mu.Unlock()
 		return
 	}
+	c.closed = true
 
-	// we should attempt to deliver the error back to the caller if it
-	// exists
+	var callsToClose map[int]*callReq
+
+	// We should attempt to deliver the error back to the caller if it
+	// exists. However, don't block c.mu while we are delivering the
+	// error to outstanding calls.
 	if err != nil {
-		c.mu.Lock()
-		for _, req := range c.calls {
-			// we need to send the error to all waiting queries, put the state
-			// of this conn into not active so that it can not execute any queries.
-			select {
-			case req.resp <- err:
-			case <-req.timeout:
-			}
+		callsToClose = c.calls
+		// It is safe to change c.calls to nil. Nobody should use it after c.closed is set to true.
+		c.calls = nil
+	}
+	c.mu.Unlock()
+
+	for _, req := range callsToClose {
+		// we need to send the error to all waiting queries.
+		select {
+		case req.resp <- callResp{err: err}:
+		case <-req.timeout:
+		}
+		if req.streamObserverContext != nil {
+			req.streamObserverEndOnce.Do(func() {
+				req.streamObserverContext.StreamAbandoned(ObservedStream{
+					Host: c.host,
+				})
+			})
+		}
-		c.mu.Unlock()
 	}
 
 	// if error was nil then unblock the quit channel
@@ -679,8 +723,8 @@ func (c *Conn) recv(ctx context.Context) error {
 		return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream)
 	} else if head.stream == -1 {
 		// TODO: handle cassandra event frames, we shouldn't get any currently
-		framer := newFramerWithExts(c, c, c.compressor, c.version, c.cqlProtoExts)
-		if err := framer.readFrame(&head); err != nil {
+		framer := newFramerWithExts(c.compressor, c.version, c.cqlProtoExts)
+		if err := framer.readFrame(c, &head); err != nil {
 			return err
 		}
 		go c.session.handleEvent(framer)
@@ -688,8 +732,8 @@ func (c *Conn) recv(ctx context.Context) error {
 	} else if head.stream <= 0 {
 		// reserved stream that we don't use, probably due to a protocol error
 		// or a bug in Cassandra, this should be an error, parse it and return.
-		framer := newFramerWithExts(c, c, c.compressor, c.version, c.cqlProtoExts)
-		if err := framer.readFrame(&head); err != nil {
+		framer := newFramerWithExts(c.compressor, c.version, c.cqlProtoExts)
+		if err := framer.readFrame(c, &head); err != nil {
 			return err
 		}
 
@@ -704,17 +748,23 @@ func (c *Conn) recv(ctx context.Context) error {
 	}
 
 	c.mu.Lock()
+	if c.closed {
+		c.mu.Unlock()
+		return ErrConnectionClosed
+	}
 	call, ok := c.calls[head.stream]
 	delete(c.calls, head.stream)
 	c.mu.Unlock()
-	if call == nil || call.framer == nil || !ok {
+	if call == nil || !ok {
 		c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
 		return c.discardFrame(head)
 	} else if head.stream != call.streamID {
 		panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream))
 	}
 
-	err = call.framer.readFrame(&head)
+	framer := newFramer(c.compressor, c.version)
+
+	err = framer.readFrame(c, &head)
 	if err != nil {
 		// only net errors should cause the connection to be closed. Though
 		// cassandra returning corrupt frames will be returned here as well.
@@ -726,7 +776,7 @@ func (c *Conn) recv(ctx context.Context) error {
 	// we either return a response to the caller, the caller timed out, or the
 	// connection has closed. Either way we should never block indefinitely here
 	select {
-	case call.resp <- err:
+	case call.resp <- callResp{framer: framer, err: err}:
 	case <-call.timeout:
 		c.releaseStream(call)
 	case <-ctx.Done():
@@ -741,6 +791,14 @@ func (c *Conn) releaseStream(call *callReq) {
 	}
 
 	c.streams.Clear(call.streamID)
+
+	if call.streamObserverContext != nil {
+		call.streamObserverEndOnce.Do(func() {
+			call.streamObserverContext.StreamFinished(ObservedStream{
+				Host: c.host,
+			})
+		})
+	}
 }
 
 func (c *Conn) handleTimeout() {
@@ -750,152 +808,255 @@ func (c *Conn) handleTimeout() {
 }
 
 type callReq struct {
-	// could use a waitgroup but this allows us to do timeouts on the read/send
-	resp    chan error
-	framer  *framer
-	timeout chan struct{} // indicates to recv() that a call has timedout
+	// resp will receive the frame that was sent as a response to this stream.
+	resp    chan callResp
+	timeout chan struct{} // indicates to recv() that a call has timed out
 
 	streamID int // current stream in use
 
 	timer *time.Timer
+
+	// streamObserverContext is notified about events regarding this stream
+	streamObserverContext StreamObserverContext
+
+	// streamObserverEndOnce ensures that either StreamAbandoned or StreamFinished is called,
+	// but not both.
+	streamObserverEndOnce sync.Once
 }
 
-type deadlineWriter struct {
-	w interface {
-		SetWriteDeadline(time.Time) error
-		io.Writer
-	}
+type callResp struct {
+	// framer is the response frame.
+	// May be nil if err is not nil.
+	framer *framer
+	// err is the error encountered, if any.
+	err error
+}
+
+// contextWriter is like io.Writer, but takes context as well.
+type contextWriter interface {
+	// writeContext writes p to the connection.
+	//
+	// If ctx is canceled before we start writing p (e.g. during waiting while another write is currently in progress),
+	// p is not written and ctx.Err() is returned. Context is ignored after we start writing p (i.e. we don't interrupt
+	// blocked writes that are in progress) so that we always either write the full frame or not write it at all.
+	//
+	// It returns the number of bytes written from p (0 <= n <= len(p)) and any error that caused the write to stop
+	// early. writeContext must return a non-nil error if it returns n < len(p). writeContext must not modify the
+	// data in p, even temporarily.
+	writeContext(ctx context.Context, p []byte) (n int, err error)
+}
+
+type deadlineWriter interface {
+	SetWriteDeadline(time.Time) error
+	io.Writer
+}
+
+type deadlineContextWriter struct {
+	w deadlineWriter
 
 	timeout time.Duration
+
+	// semaphore protects critical section for SetWriteDeadline/Write.
+	// It is a channel with capacity 1.
+	semaphore chan struct{}
+
+	// quit is closed once the connection is closed.
+	quit chan struct{}
 }
 
-func (c *deadlineWriter) Write(p []byte) (int, error) {
+// writeContext implements contextWriter.
+func (c *deadlineContextWriter) writeContext(ctx context.Context, p []byte) (int, error) {
+	select {
+	case <-ctx.Done():
+		return 0, ctx.Err()
+	case <-c.quit:
+		return 0, ErrConnectionClosed
+	case c.semaphore <- struct{}{}:
+		// acquired
+	}
+
+	defer func() {
+		// release
+		<-c.semaphore
+	}()
+
 	if c.timeout > 0 {
-		c.w.SetWriteDeadline(time.Now().Add(c.timeout))
+		err := c.w.SetWriteDeadline(time.Now().Add(c.timeout))
+		if err != nil {
+			return 0, err
+		}
 	}
 	return c.w.Write(p)
 }
 
-func newWriteCoalescer(conn net.Conn, timeout time.Duration, d time.Duration, quit <-chan struct{}) *writeCoalescer {
+func newWriteCoalescer(conn deadlineWriter, writeTimeout, coalesceDuration time.Duration,
+	quit <-chan struct{}) *writeCoalescer {
 	wc := &writeCoalescer{
-		writeCh: make(chan struct{}), // TODO: could this be sync?
-		cond:    sync.NewCond(&sync.Mutex{}),
+		writeCh: make(chan writeRequest),
 		c:       conn,
 		quit:    quit,
-		timeout: timeout,
+		timeout: writeTimeout,
 	}
-	go wc.writeFlusher(d)
+	go wc.writeFlusher(coalesceDuration)
 	return wc
 }
 
 type writeCoalescer struct {
-	c net.Conn
+	c deadlineWriter
+
+	mu sync.Mutex
 
 	quit    <-chan struct{}
-	writeCh chan struct{}
-	running bool
+	writeCh chan writeRequest
 
-	// cond waits for the buffer to be flushed
-	cond    *sync.Cond
-	buffers net.Buffers
 	timeout time.Duration
 
-	// result of the write
-	err error
-}
-
-func (w *writeCoalescer) flushLocked() {
-	w.running = false
-	if len(w.buffers) == 0 {
-		return
-	}
-
-	if w.timeout > 0 {
-		w.c.SetWriteDeadline(time.Now().Add(w.timeout))
-	}
-
-	// Given we are going to do a fanout n is useless and according to
-	// the docs WriteTo should return 0 and err or bytes written and
-	// no error.
-	_, w.err = w.buffers.WriteTo(w.c)
-	if w.err != nil {
-		w.buffers = nil
-	}
-	w.cond.Broadcast()
+	testEnqueuedHook func()
+	testFlushedHook  func()
 }
 
-func (w *writeCoalescer) flush() {
-	w.cond.L.Lock()
-	w.flushLocked()
-	w.cond.L.Unlock()
+type writeRequest struct {
+	// resultChan is a channel (with buffer size 1) where to send results of the write.
+	resultChan chan<- writeResult
+	// data to write.
+	data []byte
 }
 
-func (w *writeCoalescer) stop() {
-	w.cond.L.Lock()
-	defer w.cond.L.Unlock()
-
-	w.flushLocked()
-	// nil the channel out sends block forever on it
-	// instead of closing which causes a send on closed channel
-	// panic.
-	w.writeCh = nil
+type writeResult struct {
+	n   int
+	err error
 }
 
-func (w *writeCoalescer) Write(p []byte) (int, error) {
-	w.cond.L.Lock()
-
-	if !w.running {
-		select {
-		case w.writeCh <- struct{}{}:
-			w.running = true
-		case <-w.quit:
-			w.cond.L.Unlock()
-			return 0, io.EOF // TODO: better error here?
-		}
+// writeContext implements contextWriter.
+func (w *writeCoalescer) writeContext(ctx context.Context, p []byte) (int, error) {
+	resultChan := make(chan writeResult, 1)
+	wr := writeRequest{
+		resultChan: resultChan,
+		data:       p,
 	}
 
-	w.buffers = append(w.buffers, p)
-	for len(w.buffers) != 0 {
-		w.cond.Wait()
+	select {
+	case <-ctx.Done():
+		return 0, ctx.Err()
+	case <-w.quit:
+		return 0, io.EOF // TODO: better error here?
+	case w.writeCh <- wr:
+		// enqueued for writing
 	}
 
-	err := w.err
-	w.cond.L.Unlock()
-
-	if err != nil {
-		return 0, err
+	if w.testEnqueuedHook != nil {
+		w.testEnqueuedHook()
 	}
-	return len(p), nil
+
+	result := <-resultChan
+	return result.n, result.err
 }
 
 func (w *writeCoalescer) writeFlusher(interval time.Duration) {
 	timer := time.NewTimer(interval)
 	defer timer.Stop()
-	defer w.stop()
 
 	if !timer.Stop() {
 		<-timer.C
 	}
 
+	w.writeFlusherImpl(timer.C, func() { timer.Reset(interval) })
+}
+
+func (w *writeCoalescer) writeFlusherImpl(timerC <-chan time.Time, resetTimer func()) {
+	running := false
+
+	var buffers net.Buffers
+	var resultChans []chan<- writeResult
+
 	for {
-		// wait for a write to start the flush loop
 		select {
-		case <-w.writeCh:
+		case req := <-w.writeCh:
+			buffers = append(buffers, req.data)
+			resultChans = append(resultChans, req.resultChan)
+			if !running {
+				// Start timer on first write.
+				resetTimer()
+				running = true
+			}
 		case <-w.quit:
+			result := writeResult{
+				n:   0,
+				err: io.EOF, // TODO: better error here?
+			}
+			// Unblock whoever was waiting.
+			for _, resultChan := range resultChans {
+				// resultChan has capacity 1, so it does not block.
+				resultChan <- result
+			}
 			return
+		case <-timerC:
+			running = false
+			w.flush(resultChans, buffers)
+			buffers = nil
+			resultChans = nil
+			if w.testFlushedHook != nil {
+				w.testFlushedHook()
+			}
 		}
+	}
+}
 
-		timer.Reset(interval)
-
-		select {
-		case <-w.quit:
+func (w *writeCoalescer) flush(resultChans []chan<- writeResult, buffers net.Buffers) {
+	// Flush everything we have so far.
+	if w.timeout > 0 {
+		err := w.c.SetWriteDeadline(time.Now().Add(w.timeout))
+		if err != nil {
+			for i := range resultChans {
+				resultChans[i] <- writeResult{
+					n:   0,
+					err: err,
+				}
+			}
 			return
-		case <-timer.C:
 		}
+	}
 
-		w.flush()
+	// Copy buffers because WriteTo modifies buffers in-place.
+	buffers2 := make(net.Buffers, len(buffers))
+	copy(buffers2, buffers)
+	n, err := buffers2.WriteTo(w.c)
+	// Writes of bytes before n succeeded, writes of bytes starting from n failed with err.
+	// Use n as remaining byte counter.
+	for i := range buffers {
+		if int64(len(buffers[i])) <= n {
+			// this buffer was fully written.
+			resultChans[i] <- writeResult{
+				n:   len(buffers[i]),
+				err: nil,
+			}
+			n -= int64(len(buffers[i]))
+		} else {
+			// this buffer was not (fully) written.
+			resultChans[i] <- writeResult{
+				n:   int(n),
+				err: err,
+			}
+			n = 0
+		}
+	}
+}
 
+// addCall attempts to add a call to c.calls.
+// It fails with an error if the connection has already started closing or if a call for the given stream
+// already exists.
+func (c *Conn) addCall(call *callReq) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.closed {
+		return ErrConnectionClosed
+	}
+	existingCall := c.calls[call.streamID]
+	if existingCall != nil {
+		return fmt.Errorf("attempting to use stream already in use: %d -> %d", call.streamID,
+			existingCall.streamID)
+	}
+	c.calls[call.streamID] = call
+	return nil
+}
 
-func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
+func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
 	if ctxErr := ctx.Err(); ctxErr != nil {
 		return nil, ctxErr
 	}
@@ -907,42 +1068,79 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
 	}
 
 	// resp is basically a waiting semaphore protecting the framer
-	framer := newFramerWithExts(c, c, c.compressor, c.version, c.cqlProtoExts)
+	framer := newFramerWithExts(c.compressor, c.version, c.cqlProtoExts)
 
 	call := &callReq{
-		framer:   framer,
 		timeout:  make(chan struct{}),
 		streamID: stream,
-		resp:     make(chan error),
+		resp:     make(chan callResp),
 	}
 
-	c.mu.Lock()
-	existingCall := c.calls[stream]
-	if existingCall == nil {
-		c.calls[stream] = call
+	if c.streamObserver != nil {
+		call.streamObserverContext = c.streamObserver.StreamContext(ctx)
 	}
-	c.mu.Unlock()
 
-	if existingCall != nil {
-		return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, existingCall.streamID)
+	if err := c.addCall(call); err != nil {
+		return nil, err
 	}
 
+	// After this point, we need to either read from call.resp or close(call.timeout)
+	// since closeWithError can try to write a connection close error to call.resp.
+	// If we don't close(call.timeout) or read from call.resp, closeWithError can deadlock.
+
 	if tracer != nil {
 		framer.trace()
 	}
 
-	err := req.writeFrame(framer, stream)
+	if call.streamObserverContext != nil {
+		call.streamObserverContext.StreamStarted(ObservedStream{
+			Host: c.host,
+		})
+	}
+
+	err := req.buildFrame(framer, stream)
+	if err != nil {
+		// closeWithError will block waiting for this stream to either receive a response
+		// or for us to time out.
+		close(call.timeout)
+		// We failed to serialize the frame into a buffer.
+		// This should not affect the connection as we didn't write anything. We just free the current call.
+		c.mu.Lock()
+		if !c.closed {
+			delete(c.calls, call.streamID)
+		}
+		c.mu.Unlock()
+		// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
+		// check above could fail.
+		c.releaseStream(call)
+		return nil, err
+	}
+
+	n, err := c.w.writeContext(ctx, framer.buf)
 	if err != nil {
 		// closeWithError will block waiting for this stream to either receive a response
 		// or for us to time out, close the timeout chan here. I'm not entirely sure
 		// but we should not get a response after an error on the write side.
 		close(call.timeout)
-		// I think this is the correct thing to do, im not entirely sure. It is not
-		// ideal as readers might still get some data, but they probably wont.
-		// Here we need to be careful as the stream is not available and if all
-		// writes just timeout or fail then the pool might use this connection to
-		// send a frame on, with all the streams used up and not returned.
-		c.closeWithError(err)
+		if (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) && n == 0 {
+			// We have not started to write this frame.
+			// Release the stream as no response can come from the server on the stream.
+			c.mu.Lock()
+			if !c.closed {
+				delete(c.calls, call.streamID)
+			}
+			c.mu.Unlock()
+			// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
+			// check above could fail.
+			c.releaseStream(call)
+		} else {
+			// I think this is the correct thing to do, I'm not entirely sure. It is not
+			// ideal as readers might still get some data, but they probably won't.
+			// Here we need to be careful as the stream is not available and if all
+			// writes just time out or fail then the pool might use this connection to
+			// send a frame on, with all the streams used up and not returned.
			c.closeWithError(err)
+		}
 		return nil, err
 	}
 
@@ -970,9 +1168,9 @@
 	}
 
 	select {
-	case err := <-call.resp:
+	case resp := <-call.resp:
 		close(call.timeout)
-		if err != nil {
+		if resp.err != nil {
 			if !c.Closed() {
 				// if the connection is closed then we can't release the stream,
 				// this is because the request is still outstanding and we have
@@ -980,8 +1178,21 @@
 				// connection to close.
 				c.releaseStream(call)
 			}
-			return nil, err
+			return nil, resp.err
+		}
+
+		// don't release the stream if we detect a timeout as another request can reuse
+		// that stream and get a response for the old request, which we have no
+		// easy way of detecting.
+		//
+		// Ensure that the stream is not released if there are potentially outstanding
+		// requests on the stream to prevent nil pointer dereferences in recv().
+		defer c.releaseStream(call)
+
+		if v := resp.framer.header.version.version(); v != c.version {
+			return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 		}
+
+		return resp.framer, nil
 	case <-timeoutCh:
 		close(call.timeout)
 		c.handleTimeout()
@@ -990,22 +1201,55 @@
 		close(call.timeout)
 		return nil, ctx.Err()
 	case <-c.ctx.Done():
+		close(call.timeout)
 		return nil, ErrConnectionClosed
 	}
+}
 
-	// dont release the stream if detect a timeout as another request can reuse
-	// that stream and get a response for the old request, which we have no
-	// easy way of detecting.
-	//
-	// Ensure that the stream is not released if there are potentially outstanding
-	// requests on the stream to prevent nil pointer dereferences in recv().
-	defer c.releaseStream(call)
+// ObservedStream observes a single request/response stream.
+type ObservedStream struct {
+	// Host of the connection used to send the stream.
+	Host *HostInfo
+}
 
-	if v := framer.header.version.version(); v != c.version {
-		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
-	}
+// StreamObserver is notified about request/response pairs.
+// Streams are created for executing queries/batches or
+// internal requests to the database and might live longer than
+// execution of the query - the stream is still tracked until
+// response arrives so that stream IDs are not reused.
+type StreamObserver interface {
+	// StreamContext is called before creating a new stream.
+	// ctx is the context passed to Session.Query / Session.Batch,
+	// but might also be an internal context (for example
+	// for internal requests that use the control connection).
+	// StreamContext might return nil if it is not interested
+	// in the details of this stream.
+	// StreamContext is called before the stream is created
+	// and the returned StreamObserverContext might be discarded
+	// without any methods called on the StreamObserverContext if
+	// creation of the stream fails.
+	// Note that if you don't need to track per-stream data,
+	// you can always return the same StreamObserverContext.
+	StreamContext(ctx context.Context) StreamObserverContext
+}
 
-	return framer, nil
+// StreamObserverContext is notified about the state of a stream.
+// A stream is started every time a request is written to the server
+// and is finished when a response is received.
+// It is abandoned when the underlying network connection is closed
+// before receiving a response.
+type StreamObserverContext interface {
+	// StreamStarted is called when the stream is started.
+	// This happens just before a request is written to the wire.
+	StreamStarted(observedStream ObservedStream)
+
+	// StreamAbandoned is called when we stop waiting for a response.
+	// This happens when the underlying network connection is closed.
+	// StreamFinished won't be called if StreamAbandoned is.
+	StreamAbandoned(observedStream ObservedStream)
+
+	// StreamFinished is called when we receive a response for the stream.
+	StreamFinished(observedStream ObservedStream)
 }
 
 type preparedStatment struct {
@@ -1137,7 +1381,7 @@
 	}
 
 	var (
-		frame frameWriter
+		frame frameBuilder
 		info  *preparedStatment
 	)
 
@@ -1185,9 +1429,11 @@
 			customPayload: qry.customPayload,
 		}
 
-		// Set "lwt" property in the query if it is present in preparedMetadata
+		// Set "lwt", "keyspace", "table" properties in the query if present in preparedMetadata
 		qry.routingInfo.mu.Lock()
 		qry.routingInfo.lwt = info.request.lwt
+		qry.routingInfo.keyspace = info.request.keyspace
+		qry.routingInfo.table = info.request.table
 		qry.routingInfo.mu.Unlock()
 	} else {
 		frame = &writeQueryFrame{
@@ -1283,7 +1529,9 @@
 }
 
 func (c *Conn) Pick(qry *Query) *Conn {
 }
 
 func (c *Conn) Closed() bool {
-	return atomic.LoadInt32(&c.closed) == 1
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	return c.closed
 }
 
 func (c *Conn) Address() string {
@@ -1447,18 +1695,51 @@
 	return c.executeQuery(ctx, q)
 }
 
-func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
+func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter {
 	const (
-		peerSchemas  = "SELECT * FROM system.peers"
-		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
+		peerSchema    = "SELECT * FROM system.peers"
+		peerV2Schemas = "SELECT * FROM system.peers_v2"
 	)
 
+	c.mu.Lock()
+	isSchemaV2 := c.isSchemaV2
+	c.mu.Unlock()
+
+	if version.AtLeast(4, 0, 0) && isSchemaV2 {
+		// Try "system.peers_v2" and fall back to "system.peers" if it's not found
+		iter := c.query(ctx, peerV2Schemas)
+
+		err := iter.checkErrAndNotFound()
+		if err != nil {
+			if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers
+				c.mu.Lock()
+				c.isSchemaV2 = false
+				c.mu.Unlock()
+				return c.query(ctx, peerSchema)
+			} else {
+				return iter
+			}
+		}
+		return iter
+	} else {
+		return c.query(ctx, peerSchema)
+	}
+}
+
+func (c *Conn) querySystemLocal(ctx context.Context) *Iter {
+	return c.query(ctx, "SELECT * FROM system.local WHERE key='local'")
+}
+
+func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
+	const localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
+
 	var versions map[string]struct{}
 	var schemaVersion string
 
 	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
+
 	for time.Now().Before(endDeadline) {
-		iter := c.query(ctx, peerSchemas)
+		iter := c.querySystemPeers(ctx, c.host.version)
 
 		versions = make(map[string]struct{})
@@ -1519,22 +1800,6 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 
-func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
-	row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap()
-	if err != nil {
-		return nil, err
-	}
-
-	port := c.conn.RemoteAddr().(*net.TCPAddr).Port
-	// TODO(zariel): avoid doing this here
-	host, err := c.session.hostInfoFromMap(row, &HostInfo{hostname: c.host.connectAddress.String(), connectAddress: c.host.connectAddress, port: port})
-	if err != nil {
-		return nil, err
-	}
-
-	return c.session.ring.addOrUpdate(host), nil
-}
-
 var (
 	ErrQueryArgLength    = errors.New("gocql: query argument length mismatch")
 	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
diff --git a/vendor/github.com/gocql/gocql/connectionpool.go b/vendor/github.com/gocql/gocql/connectionpool.go
index a11434b27c5..d207fa0aaca 100644
--- a/vendor/github.com/gocql/gocql/connectionpool.go
+++ b/vendor/github.com/gocql/gocql/connectionpool.go
@@ -126,6 +126,7 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
 		ProtoVersion:   cfg.ProtoVersion,
 		CQLVersion:     cfg.CQLVersion,
 		Timeout:        cfg.Timeout,
+		WriteTimeout:   cfg.WriteTimeout,
 		ConnectTimeout: cfg.ConnectTimeout,
 		Dialer:         cfg.Dialer,
 		HostDialer:     hostDialer,
@@ -332,7 +333,7 @@ func (pool *hostConnPool) Pick(token token) *Conn {
 	return pool.connPicker.Pick(token)
 }
 
-//Size returns the number of connections currently active in the pool
+// Size returns the number of connections currently active in the pool
 func (pool *hostConnPool) Size() int {
 	pool.mu.RLock()
 	defer pool.mu.RUnlock()
@@ -341,7 +342,7 @@ func (pool *hostConnPool) Size() int {
 	size, _ := pool.connPicker.Size()
 	return size
 }
 
-//Close the connection pool
+// Close the connection pool
 func (pool *hostConnPool) Close() {
 	pool.mu.Lock()
 	defer pool.mu.Unlock()
@@ -398,13 +399,7 @@ func (pool *hostConnPool) fill() {
 		if err != nil {
 			// probably unreachable host
-			pool.fillingStopped(true)
-
-			// this is call with the connection pool mutex held, this call will
-			// then recursively try to lock it again. FIXME
-			if pool.session.cfg.ConvictionPolicy.AddFailure(err, pool.host) {
-				go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.host.Port())
-			}
+			pool.fillingStopped(err)
 			return
 		}
 		// notify the session that this node is connected
@@ -421,7 +416,7 @@ func (pool *hostConnPool) fill() {
 		err := pool.connectMany(fillCount)
 
 		// mark the end of filling
-		pool.fillingStopped(err != nil)
+		pool.fillingStopped(err)
 
 		if err == nil && startCount > 0 {
 			// notify the session that this node is connected again
@@ -444,8 +439,11 @@ func (pool *hostConnPool) logConnectErr(err error) {
 	}
 }
 
 // transition back to a not-filling state.
-func (pool *hostConnPool) fillingStopped(hadError bool) {
-	if hadError {
+func (pool *hostConnPool) fillingStopped(err error) {
+	if err != nil {
+		if gocqlDebug {
+			pool.logger.Printf("gocql: filling stopped %q: %v\n", pool.host.ConnectAddress(), err)
+		}
 		// wait for some time to avoid back-to-back filling
 		// this provides some time between failed attempts
 		// to fill the pool for the host to recover
@@ -454,7 +452,21 @@
 
 	pool.mu.Lock()
 	pool.filling = false
+	count, _ := pool.connPicker.Size()
+	host := pool.host
+	port := pool.host.Port()
 	pool.mu.Unlock()
+
+	// if we errored and the size is now zero, make sure the host is marked as down
+	// see https://github.com/gocql/gocql/issues/1614
+	if gocqlDebug {
+		pool.logger.Printf("gocql: conns of pool after stopped %q: %v\n", host.ConnectAddress(), count)
+	}
+	if err != nil && count == 0 {
+		if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) {
+			pool.session.handleNodeDown(host.ConnectAddress(), port)
+		}
+	}
 }
 
 // connectMany creates new connections concurrently.
@@ -510,7 +522,7 @@ func (pool *hostConnPool) connect() (err error) {
 		}
 	}
 	if gocqlDebug {
-		pool.logger.Printf("connection failed %q: %v, reconnecting with %T\n",
+		pool.logger.Printf("gocql: connection failed %q: %v, reconnecting with %T\n",
 			pool.host.ConnectAddress(), err, reconnectionPolicy)
 	}
 	time.Sleep(reconnectionPolicy.GetInterval(i))
@@ -574,5 +586,9 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
 		return
 	}
 
+	if gocqlDebug {
+		pool.logger.Printf("gocql: pool connection error %q: %v\n", conn.addr, err)
+	}
+
 	pool.connPicker.Remove(conn)
 }
diff --git a/vendor/github.com/gocql/gocql/control.go b/vendor/github.com/gocql/gocql/control.go
index 0501366376a..fe2ce06e38f 100644
--- a/vendor/github.com/gocql/gocql/control.go
+++ b/vendor/github.com/gocql/gocql/control.go
@@ -98,7 +98,7 @@ func (c *controlConn) heartBeat() {
 	reconn:
 		// try to connect a bit faster
 		sleepTime = 1 * time.Second
-		c.reconnect(true)
+		c.reconnect()
 		continue
 	}
 }
@@ -167,28 +167,6 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
 	return shuffled
 }
 
-func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
-	// shuffle endpoints so not all drivers will connect to the same initial
-	// node.
-	shuffled := shuffleHosts(endpoints)
-
-	cfg := *c.session.connCfg
-	cfg.disableCoalesce = true
-
-	var err error
-	for _, host := range shuffled {
-		var conn *Conn
-		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
-		if err == nil {
-			return conn, nil
-		}
-
-		c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
-	}
-
-	return nil, err
-}
-
 // this is going to be version dependant and a nightmare to maintain :(
 var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
 
@@ -249,14 +227,31 @@ func (c *controlConn) connect(hosts []*HostInfo) error {
 		return errors.New("control: no endpoints specified")
 	}
 
-	conn, err := c.shuffleDial(hosts)
-	if err != nil {
-		return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
-	}
+	// shuffle endpoints so not all drivers will connect to the same initial
+	// node.
+	hosts = shuffleHosts(hosts)
 
-	if err := c.setupConn(conn); err != nil {
+	cfg := *c.session.connCfg
+	cfg.disableCoalesce = true
+
+	var conn *Conn
+	var err error
+	for _, host := range hosts {
+		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
+		if err != nil {
+			c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+			continue
+		}
+		err = c.setupConn(conn)
+		if err == nil {
+			break
+		}
+		c.session.logger.Printf("gocql: unable to set up control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
 		conn.Close()
-		return fmt.Errorf("control: unable to setup connection: %v", err)
+		conn = nil
+	}
+	if conn == nil {
+		return fmt.Errorf("unable to connect to initial hosts: %v", err)
 	}
 
 	// we could fetch the initial ring here and update initial host data. So that
@@ -273,16 +268,21 @@ type connHost struct {
 }
 
 func (c *controlConn) setupConn(conn *Conn) error {
-	if err := c.registerEvents(conn); err != nil {
-		conn.Close()
+	// we need up-to-date host info for the filterHost call below
+	iter := conn.querySystemLocal(context.TODO())
+	host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, conn.conn.RemoteAddr().(*net.TCPAddr).Port)
+	if err != nil {
 		return err
 	}
 
-	// TODO(zariel): do we need to fetch host info everytime
-	// the control conn connects? Surely we have it cached?
-	host, err := conn.localHostInfo(context.TODO())
-	if err != nil {
-		return err
+	host = c.session.ring.addOrUpdate(host)
+
+	if c.session.cfg.filterHost(host) {
+		return fmt.Errorf("host was filtered: %v", host.ConnectAddress())
+	}
+
+	if err := c.registerEvents(conn); err != nil {
+		return fmt.Errorf("register events: %v", err)
 	}
 
 	ch := &connHost{
@@ -338,7 +338,7 @@ func (c *controlConn) registerEvents(conn *Conn) error {
 	return nil
 }
 
-func (c *controlConn) reconnect(refreshring bool) {
+func (c *controlConn) reconnect() {
 	if atomic.LoadInt32(&c.state) == controlConnClosing {
 		return
 	}
@@ -346,57 +346,76 @@
 		return
 	}
 	defer atomic.StoreInt32(&c.reconnecting, 0)
-	// TODO: simplify this function, use session.ring to get hosts instead of the
-	// connection pool
 
-	var host *HostInfo
-	ch := c.getConn()
-	if ch != nil {
-		host = ch.host
-		ch.conn.Close()
+	conn, err := c.attemptReconnect()
+
+	if conn == nil {
+		c.session.logger.Printf("gocql: unable to reconnect control connection: %v\n", err)
+		return
 	}
 
-	var newConn *Conn
-	if host != nil {
-		// try to connect to the old host
-		conn, err := c.session.connect(c.session.ctx, host, c)
-		if err != nil {
-			// host is dead
-			// TODO: this is replicated in a few places
-			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
-				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
+	err = c.session.refreshRing()
+	if err != nil {
+		c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
+	}
+}
+
+func (c *controlConn) attemptReconnect() (*Conn, error) {
+	hosts := c.session.ring.allHosts()
+	hosts = shuffleHosts(hosts)
+
+	// keep the old behavior of connecting to the old host first by moving it to
+	// the front of the slice
+	ch := c.getConn()
+	if ch != nil {
+		for i := range hosts {
+			if hosts[i].Equal(ch.host) {
+				hosts[0], hosts[i] = hosts[i], hosts[0]
+				break
 			}
-		} else {
-			newConn = conn
 		}
+		ch.conn.Close()
 	}
 
-	// TODO: should have our own round-robin for hosts so that we can try each
-	// in succession and guarantee that we get a different host each time.
-	if newConn == nil {
-		host := c.session.ring.rrHost()
-		if host == nil {
-			c.connect(c.session.ring.endpoints)
-			return
-		}
+	conn, err := c.attemptReconnectToAnyOfHosts(hosts)
 
-		var err error
-		newConn, err = c.session.connect(c.session.ctx, host, c)
-		if err != nil {
-			// TODO: add log handler for things like this
-			return
-		}
+	if conn != nil {
+		return conn, err
 	}
 
-	if err := c.setupConn(newConn); err != nil {
-		newConn.Close()
-		c.session.logger.Printf("gocql: control unable to register events: %v\n", err)
-		return
+	c.session.logger.Printf("gocql: unable to connect to any ring node: %v\n", err)
+	c.session.logger.Printf("gocql: control falling back to initial contact points.\n")
+	// Fall back to the initial contact points, as it may be the case that all known initialHosts
+	// changed their IPs while keeping the same hostname(s).
+	initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
+	if resolvErr != nil {
+		return nil, fmt.Errorf("resolve contact points' hostnames: %v", resolvErr)
 	}
 
-	if refreshring {
-		c.session.hostSource.refreshRing()
+	return c.attemptReconnectToAnyOfHosts(initialHosts)
+}
+
+func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, error) {
+	var conn *Conn
+	var err error
+	for _, host := range hosts {
+		conn, err = c.session.connect(c.session.ctx, host, c)
+		if err != nil {
+			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
+				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
+			}
+			c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+			continue
+		}
+		err = c.setupConn(conn)
+		if err == nil {
+			break
+		}
+		c.session.logger.Printf("gocql: unable to set up control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
+		conn.Close()
+		conn = nil
 	}
+	return conn, err
 }
 
 func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
@@ -412,14 +431,14 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
 		return
 	}
 
-	c.reconnect(false)
+	c.reconnect()
 }
 
 func (c *controlConn) getConn() *connHost {
 	return c.conn.Load().(*connHost)
 }
 
-func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
+func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
 	ch := c.getConn()
 	if ch == nil {
 		return nil, errNoControl
@@ -446,7 +465,7 @@ func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
 
 		connectAttempts++
 
-		c.reconnect(false)
+		c.reconnect()
 		continue
 	}
 
diff --git a/vendor/github.com/gocql/gocql/debug_off.go b/vendor/github.com/gocql/gocql/debug_off.go
index 3af3ae0f3eb..31e622599d4 100644
--- a/vendor/github.com/gocql/gocql/debug_off.go
+++ b/vendor/github.com/gocql/gocql/debug_off.go
@@ -1,3 +1,4 @@
+//go:build !gocql_debug
 // +build !gocql_debug
 
 package gocql
diff --git a/vendor/github.com/gocql/gocql/debug_on.go b/vendor/github.com/gocql/gocql/debug_on.go
index e94a00ce5b7..b3bdfab8db6 100644
--- a/vendor/github.com/gocql/gocql/debug_on.go
+++ b/vendor/github.com/gocql/gocql/debug_on.go
@@ -1,3 +1,4 @@
+//go:build gocql_debug
 // +build gocql_debug
 
 package gocql
diff --git a/vendor/github.com/gocql/gocql/dial.go b/vendor/github.com/gocql/gocql/dial.go
index 71c0611bc20..0613cebe01b 100644
--- a/vendor/github.com/gocql/gocql/dial.go
+++ b/vendor/github.com/gocql/gocql/dial.go
@@ -45,11 +45,12 @@ func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*Dia
 		return nil, fmt.Errorf("host missing port: %v", port)
 	}
 
-	addr := host.HostnameAndPort()
-	conn, err := hd.dialer.DialContext(ctx, "tcp", addr)
+	connAddr := host.ConnectAddressAndPort()
+	conn, err := hd.dialer.DialContext(ctx, "tcp", connAddr)
 	if err != nil {
 		return nil, err
 	}
+	addr := host.HostnameAndPort()
 	return WrapTLS(ctx, conn, addr, hd.tlsConfig)
 }
diff --git a/vendor/github.com/gocql/gocql/doc.go b/vendor/github.com/gocql/gocql/doc.go
index af8f4c86ca5..6739d98e4ab 100644
--- a/vendor/github.com/gocql/gocql/doc.go
+++ b/vendor/github.com/gocql/gocql/doc.go
@@ -5,15 +5,15 @@
 // Package gocql implements a fast and robust Cassandra driver for the
 // Go programming language.
 //
-// Connecting to the cluster
+// # Connecting to the cluster
 //
 // Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
 //
-//   cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+//	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
 //
 // Port can be specified as part of the address, the above is equivalent to:
 //
-//   cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
+//	cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
 //
 // It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
 // an IP address not a domain name. This is because events from Cassandra will use the configured IP
 //
@@ -22,23 +22,26 @@
 //
 // Then you can customize more options (see ClusterConfig):
 //
-//   cluster.Keyspace = "example"
-//   cluster.Consistency = gocql.Quorum
-//   cluster.ProtoVersion = 4
+//	cluster.Keyspace = "example"
+//	cluster.Consistency = gocql.Quorum
+//	cluster.ProtoVersion = 4
 //
 // The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
 // protocol version explicitly, as it's not defined which version will be used in certain situations (for example
 // during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
 //
+// The driver advertises the module name and version in the STARTUP message, so servers are able to detect the version.
+// If you use a replace directive in go.mod, the driver will send information about the replacement module instead.
+//
 // When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
 //
-//   session, err := cluster.CreateSession()
-//   if err != nil {
-//     return err
-//   }
-//   defer session.Close()
+//	session, err := cluster.CreateSession()
+//	if err != nil {
+//		return err
+//	}
+//	defer session.Close()
 //
-// Authentication
+// # Authentication
 //
 // CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
 // client response pairs. The details of the exchanged messages depend on the authenticator used.
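Beyond the godoc reflow, this bump also changes the timeout defaults documented in cluster.go above (Timeout and ConnectTimeout move from 600 ms to 11 s) and adds the WriteTimeout and MaxRequestsPerConn fields. A minimal sketch of wiring these knobs up (the host address and all values are illustrative, not recommendations):

```go
package main

import (
	"time"

	"github.com/gocql/gocql"
)

func newSession() (*gocql.Session, error) {
	cluster := gocql.NewCluster("192.168.1.1")
	cluster.Timeout = 5 * time.Second        // client-side cap per query/batch execution (default is now 11s)
	cluster.ConnectTimeout = 5 * time.Second // covers dialing plus setup queries and AUTH (default is now 11s)
	cluster.WriteTimeout = 2 * time.Second   // cap on writing a request; defaults to Timeout when unset
	cluster.MaxRequestsPerConn = 2048        // in-flight requests per connection (default 32768 for CQL v3+)
	return cluster.CreateSession()
}
```

Per the field documentation in cluster.go, WriteTimeout should be kept lower than or equal to Timeout, and the client Timeout should stay above the server-side request timeouts so retries do not pile onto an overloaded server.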
@@ -47,18 +50,18 @@
 //
 // PasswordAuthenticator is provided to use for username/password authentication:
 //
-//   cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
-//   cluster.Authenticator = gocql.PasswordAuthenticator{
-//     Username: "user",
-//     Password: "password"
-//   }
-//   session, err := cluster.CreateSession()
-//   if err != nil {
-//     return err
-//   }
-//   defer session.Close()
+//	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+//	cluster.Authenticator = gocql.PasswordAuthenticator{
+//		Username: "user",
+//		Password: "password"
+//	}
+//	session, err := cluster.CreateSession()
+//	if err != nil {
+//		return err
+//	}
+//	defer session.Close()
 //
-// Transport layer security
+// # Transport layer security
 //
 // It is possible to secure traffic between the client and server with TLS.
 //
@@ -69,38 +72,38 @@
 // to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
 // SslOptions and Config.InsecureSkipVerify interact as follows:
 //
-// Config.InsecureSkipVerify | EnableHostVerification | Result
-// Config is nil             | false                  | do not verify host
-// Config is nil             | true                   | verify host
-// false                     | false                  | verify host
-// true                      | false                  | do not verify host
-// false                     | true                   | verify host
-// true                      | true                   | verify host
+//	Config.InsecureSkipVerify | EnableHostVerification | Result
+//	Config is nil             | false                  | do not verify host
+//	Config is nil             | true                   | verify host
+//	false                     | false                  | verify host
+//	true                      | false                  | do not verify host
+//	false                     | true                   | verify host
+//	true                      | true                   | verify host
 //
 // For example:
 //
-//   cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
-//   cluster.SslOpts = &gocql.SslOptions{
-//     EnableHostVerification: true,
-//   }
-//   session, err := cluster.CreateSession()
-//   if err != nil {
-//     return err
-//   }
-//   defer session.Close()
+//	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+//	cluster.SslOpts = &gocql.SslOptions{
+//		EnableHostVerification: true,
+//	}
+//	session, err := cluster.CreateSession()
+//	if err != nil {
+//		return err
+//	}
+//	defer session.Close()
 //
-// Data-center awareness and query routing
+// # Data-center awareness and query routing
 //
 // To route queries to local DC first, use DCAwareRoundRobinPolicy. For example, if the datacenter you
 // want to primarily connect is called dc1 (as configured in the database):
 //
-//   cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
-//   cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy("dc1")
+//	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+//	cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy("dc1")
 //
 // The driver can route queries to nodes that hold data replicas based on partition key (preferring local DC).
 //
-//   cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
-//   cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy("dc1"))
+//	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
+//	cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy("dc1"))
 //
 // Note that TokenAwareHostPolicy can take options such as gocql.ShuffleReplicas and gocql.NonLocalReplicasFallback.
 //
@@ -109,50 +112,59 @@
 // The driver can only use token-aware routing for queries where all partition key columns are query parameters.
 // For example, instead of
 //
-//   session.Query("select value from mytable where pk1 = 'abc' AND pk2 = ?", "def")
+//	session.Query("select value from mytable where pk1 = 'abc' AND pk2 = ?", "def")
 //
 // use
 //
-//   session.Query("select value from mytable where pk1 = ? AND pk2 = ?", "abc", "def")
+//	session.Query("select value from mytable where pk1 = ? AND pk2 = ?", "abc", "def")
+//
+// # Rack-level awareness
+//
+// The DCAwareRoundRobinPolicy can be replaced with RackAwareRoundRobinPolicy, which takes two parameters, datacenter and rack.
+//
+// Instead of dividing hosts with two tiers (local datacenter and remote datacenters) it divides hosts into three
+// (the local rack, the rest of the local datacenter, and everything else).
+//
+// RackAwareRoundRobinPolicy can be combined with TokenAwareHostPolicy in the same way as DCAwareRoundRobinPolicy.
 //
-// Executing queries
+// # Executing queries
 //
 // Create queries with Session.Query. Query values must not be reused between different executions and must not be
 // modified after starting execution of the query.
 //
 // To execute a query without reading results, use Query.Exec:
 //
-//   err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
-//     "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
+//	err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
+//		"me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
 //
 // Single row can be read by calling Query.Scan:
 //
-//   err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
-//     "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
+//	err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
+//		"me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
 //
 // Multiple rows can be read using Iter.Scanner:
 //
-//   scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
-//     "me").WithContext(ctx).Iter().Scanner()
-//   for scanner.Next() {
-//     var (
-//       id gocql.UUID
-//       text string
-//     )
-//     err = scanner.Scan(&id, &text)
-//     if err != nil {
-//       log.Fatal(err)
-//     }
-//     fmt.Println("Tweet:", id, text)
-//   }
-//   // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
-//   if err := scanner.Err(); err != nil {
-//     log.Fatal(err)
-//   }
+//	scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
+//		"me").WithContext(ctx).Iter().Scanner()
+//	for scanner.Next() {
+//		var (
+//			id   gocql.UUID
+//			text string
+//		)
+//		err = scanner.Scan(&id, &text)
+//		if err != nil {
+//			log.Fatal(err)
+//		}
+//		fmt.Println("Tweet:", id, text)
+//	}
+//	// scanner.Err() closes the iterator, so neither scanner nor iter should be used afterwards.
+//	if err := scanner.Err(); err != nil {
+//		log.Fatal(err)
+//	}
 //
 // See Example for complete example.
 //
-// Prepared statements
+// # Prepared statements
 //
 // The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
 // of prepared statements.
@@ -163,71 +175,71 @@
 // The main advantage is the ability to keep the same prepared statement even when you don't
 // want to update some fields, where before you needed to make another prepared statement.
 //
-// Executing multiple queries concurrently
+// # Executing multiple queries concurrently
 //
 // Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
Gocql provides a synchronous-looking API (as recommended for Go APIs) and the queries // are executed asynchronously at the protocol level. // -// results := make(chan error, 2) -// go func() { -// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, -// "me", gocql.TimeUUID(), "hello world 1").Exec() -// }() -// go func() { -// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, -// "me", gocql.TimeUUID(), "hello world 2").Exec() -// }() +// results := make(chan error, 2) +// go func() { +// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, +// "me", gocql.TimeUUID(), "hello world 1").Exec() +// }() +// go func() { +// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, +// "me", gocql.TimeUUID(), "hello world 2").Exec() +// }() // -// Nulls +// # Nulls // // Null values are unmarshalled as the zero value of the type. If you need to distinguish, for example, between a text // column being null and an empty string, you can unmarshal into a *string variable instead of a string. // -// var text *string -// err := scanner.Scan(&text) -// if err != nil { -// // handle error -// } -// if text != nil { -// // not null -// } -// else { -// // null -// } +// var text *string +// err := scanner.Scan(&text) +// if err != nil { +// // handle error +// } +// if text != nil { +// // not null +// } else { +// // null +// } // // See Example_nulls for a full example. // -// Reusing slices +// # Reusing slices // // The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not // need to be allocated for every processed row. However, you need to be careful when storing the slices in other // memory structures. // -// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner() -// var myInts []int -// for scanner.Next() { -// // This scan reuses backing store of myInts for each row. -// err = scanner.Scan(&myInts) -// if err != nil { -// log.Fatal(err) -// } -// } +// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner() +// var myInts []int +// for scanner.Next() { +// // This scan reuses backing store of myInts for each row. +// err = scanner.Scan(&myInts) +// if err != nil { +// log.Fatal(err) +// } +// } // // When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the // slice variable within the scanner loop: // -// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner() -// for scanner.Next() { -// var myInts []int -// // This scan always gets pointer to fresh myInts slice, so does not reuse memory. -// err = scanner.Scan(&myInts) -// if err != nil { -// log.Fatal(err) -// } -// } +// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner() +// for scanner.Next() { +// var myInts []int +// // This scan always gets pointer to fresh myInts slice, so does not reuse memory. +// err = scanner.Scan(&myInts) +// if err != nil { +// log.Fatal(err) +// } +// } // -// Paging +// # Paging // // The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch, // Query.PageSize, and Query.Prefetch. @@ -258,14 +270,14 @@ // // See Example_paging for an example of manual paging.
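//
// A minimal sketch of manual paging (illustrative names; it assumes the
// Query.PageState and Iter.PageState APIs referenced above and reuses the
// tweet table from the earlier examples; setting an explicit page state
// disables automatic paging, so each call fetches a single page):
//
//	fetchPage := func(ctx context.Context, pageState []byte) ([]byte, error) {
//		iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
//			"me").WithContext(ctx).PageSize(100).PageState(pageState).Iter()
//		var (
//			id   gocql.UUID
//			text string
//		)
//		for iter.Scan(&id, &text) {
//			fmt.Println("Tweet:", id, text)
//		}
//		// The returned state resumes the query on the next call; it is
//		// empty once the last page has been read.
//		next := iter.PageState()
//		return next, iter.Close()
//	}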
// -// Dynamic list of columns +// # Dynamic list of columns // // There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied // by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case. // // See Example_dynamicColumns. // -// Batches +// # Batches // // The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql. // Use Session.NewBatch to create a new batch and then fill-in details of individual queries. @@ -292,7 +304,7 @@ // // See Example_batch for an example. // -// Lightweight transactions +// # Lightweight transactions // // Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an // INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS. @@ -302,7 +314,7 @@ // Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for // Session.MapExecuteBatchCAS. // -// Retries and speculative execution +// # Retries and speculative execution // // Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed // multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative @@ -316,21 +328,21 @@ // is still executing. The two parallel executions of the query race to return a result, the first received result will // be returned. // -// User-defined types +// # User-defined types // // UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing // UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces). // // For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field: // -// type MyUDT struct { -// FieldA int32 `cql:"a"` -// FieldB string `cql:"b"` -// } +// type MyUDT struct { +// FieldA int32 `cql:"a"` +// FieldB string `cql:"b"` +// } // // See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler. 
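//
// A minimal sketch of the UDTMarshaler interface mentioned above (the
// coordinate UDT and its lat/lon fields are illustrative):
//
//	type Coord struct {
//		Lat float64
//		Lon float64
//	}
//
//	// MarshalUDT is called once per UDT field with that field's TypeInfo.
//	func (c Coord) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
//		switch name {
//		case "lat":
//			return gocql.Marshal(info, c.Lat)
//		case "lon":
//			return gocql.Marshal(info, c.Lon)
//		default:
//			return nil, fmt.Errorf("unknown UDT field %q", name)
//		}
//	}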
// -// Metrics and tracing +// # Metrics and tracing // // It is possible to provide observer implementations that could be used to gather metrics: // diff --git a/vendor/github.com/gocql/gocql/docker-compose.yml b/vendor/github.com/gocql/gocql/docker-compose.yml index 0f07930e3b1..9e7490c7d5e 100644 --- a/vendor/github.com/gocql/gocql/docker-compose.yml +++ b/vendor/github.com/gocql/gocql/docker-compose.yml @@ -4,7 +4,13 @@ services: node_1: image: ${SCYLLA_IMAGE} privileged: true - command: --smp 2 --memory 512M --seeds 192.168.100.11 --overprovisioned 1 + command: | + --smp 2 + --memory 768M + --seeds 192.168.100.11 + --overprovisioned 1 + --experimental-features udf + --enable-user-defined-functions true networks: public: ipv4_address: 192.168.100.11 @@ -21,6 +27,11 @@ services: - type: bind source: ./testdata/pki/cassandra.key target: /etc/scylla/db.key + healthcheck: + test: [ "CMD", "cqlsh", "-e", "select * from system.local" ] + interval: 5s + timeout: 5s + retries: 18 networks: public: driver: bridge diff --git a/vendor/github.com/gocql/gocql/errors.go b/vendor/github.com/gocql/gocql/errors.go index faa6f7c9dd9..4fb37268395 100644 --- a/vendor/github.com/gocql/gocql/errors.go +++ b/vendor/github.com/gocql/gocql/errors.go @@ -196,3 +196,28 @@ type RequestErrCASWriteUnknown struct { Received int BlockFor int } + +type OpType uint8 + +const ( + OpTypeRead OpType = 0 + OpTypeWrite OpType = 1 +) + +type RequestErrRateLimitReached struct { + errorFrame + OpType OpType + RejectedByCoordinator bool +} + +func (e *RequestErrRateLimitReached) String() string { + var opType string + if e.OpType == OpTypeRead { + opType = "Read" + } else if e.OpType == OpTypeWrite { + opType = "Write" + } else { + opType = "Other" + } + return fmt.Sprintf("[request_error_rate_limit_reached OpType=%s RejectedByCoordinator=%t]", opType, e.RejectedByCoordinator) +} diff --git a/vendor/github.com/gocql/gocql/events.go b/vendor/github.com/gocql/gocql/events.go index 395f156e07c..73461f629e4 100644 --- a/vendor/github.com/gocql/gocql/events.go +++ b/vendor/github.com/gocql/gocql/events.go @@ -129,6 +129,16 @@ func (s *Session) handleKeyspaceChange(keyspace, change string) { s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change}) } +// handleNodeEvent handles inbound status and topology change events. +// +// Status events are debounced by host IP; only the latest event is processed. +// +// Topology events are debounced by performing a single full topology refresh +// whenever any topology event comes in. +// +// Processing topology change events before status change events ensures +// that a NEW_NODE event is not dropped in favor of a newer UP event (which +// would itself be dropped/ignored, as the node is not yet known). func (s *Session) handleNodeEvent(frames []frame) { type nodeEvent struct { change string @@ -136,48 +146,36 @@ func (s *Session) handleNodeEvent(frames []frame) { port int } - events := make(map[string]*nodeEvent) + topologyEventReceived := false + // status change events + sEvents := make(map[string]*nodeEvent) for _, frame := range frames { - // TODO: can we be sure the order of events in the buffer is correct? 
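// Classify each buffered event frame below: every topology change simply
// sets topologyEventReceived (all topology changes are coalesced into a
// single debounced ring refresh), while status changes are deduplicated per
// host address so that only the most recent UP/DOWN event is dispatched.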
switch f := frame.(type) { case *topologyChangeEventFrame: - event, ok := events[f.host.String()] - if !ok { - event = &nodeEvent{change: f.change, host: f.host, port: f.port} - events[f.host.String()] = event - } - event.change = f.change - + topologyEventReceived = true case *statusChangeEventFrame: - event, ok := events[f.host.String()] + event, ok := sEvents[f.host.String()] if !ok { event = &nodeEvent{change: f.change, host: f.host, port: f.port} - events[f.host.String()] = event + sEvents[f.host.String()] = event } event.change = f.change } } - for _, f := range events { + if topologyEventReceived && !s.cfg.Events.DisableTopologyEvents { + s.debounceRingRefresh() + } + + for _, f := range sEvents { if gocqlDebug { - s.logger.Printf("gocql: dispatching event: %+v\n", f) + s.logger.Printf("gocql: dispatching status change event: %+v\n", f) } // ignore events we received if they were disabled // see https://github.com/gocql/gocql/issues/1591 switch f.change { - case "NEW_NODE": - if !s.cfg.Events.DisableTopologyEvents { - s.handleNewNode(f.host, f.port) - } - case "REMOVED_NODE": - if !s.cfg.Events.DisableTopologyEvents { - s.handleRemovedNode(f.host, f.port) - } - case "MOVED_NODE": - // java-driver handles this, not mentioned in the spec - // TODO(zariel): refresh token map case "UP": if !s.cfg.Events.DisableNodeStatusEvents { s.handleNodeUp(f.host, f.port) @@ -190,73 +188,6 @@ func (s *Session) handleNodeEvent(frames []frame) { } } -func (s *Session) addNewNode(hostID UUID) { - // Get host info and apply any filters to the host - hostInfo, err := s.hostSource.getHostInfo(hostID) - if err != nil { - s.logger.Printf("gocql: events: unable to fetch host info for hostID: %q: %v\n", hostID, err) - return - } else if hostInfo == nil { - // ignore if it's null because we couldn't find it - return - } - - if t := hostInfo.Version().nodeUpDelay(); t > 0 { - time.Sleep(t) - } - - // should this handle token moving? 
- hostInfo = s.ring.addOrUpdate(hostInfo) - - if !s.cfg.filterHost(hostInfo) { - s.startPoolFill(hostInfo) - } - - if s.control != nil && !s.cfg.IgnorePeerAddr { - // TODO(zariel): debounce ring refresh - s.hostSource.refreshRing() - } -} - -func (s *Session) handleNewNode(ip net.IP, port int) { - if gocqlDebug { - s.logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port) - } - - host, ok := s.ring.getHostByIP(ip.String()) - if ok && host.IsUp() { - return - } - - if err := s.hostSource.refreshRing(); err != nil && gocqlDebug { - s.logger.Printf("gocql: Session.handleNewNode: failed to refresh ring: %w\n", err.Error()) - } -} - -func (s *Session) handleRemovedNode(ip net.IP, port int) { - if gocqlDebug { - s.logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port) - } - - // we remove all nodes but only add ones which pass the filter - host, ok := s.ring.getHostByIP(ip.String()) - if ok { - hostID := host.HostID() - s.ring.removeHost(hostID) - - host.setState(NodeDown) - if !s.cfg.filterHost(host) { - s.policy.RemoveHost(host) - s.pool.removeHost(hostID) - } - - } - - if err := s.hostSource.refreshRing(); err != nil && gocqlDebug { - s.logger.Println("failed to refresh ring:", err) - } -} - func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) { if gocqlDebug { s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort) @@ -264,6 +195,7 @@ func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) { host, ok := s.ring.getHostByIP(eventIp.String()) if !ok { + s.debounceRingRefresh() return } diff --git a/vendor/github.com/gocql/gocql/exec.go b/vendor/github.com/gocql/gocql/exec.go index a4ed3acd441..8d4d6a310af 100644 --- a/vendor/github.com/gocql/gocql/exec.go +++ b/vendor/github.com/gocql/gocql/exec.go @@ -77,12 +77,32 @@ func NewSingleHostQueryExecutor(cfg *ClusterConfig) (e SingleHostQueryExecutor, // Create control connection to one of the hosts e.control = createControlConn(e.session) + + // shuffle endpoints so not all drivers will connect to the same initial + // node. 
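+	// (Shuffling only affects which node this executor tries first; the loop
+	// below keeps the first connection whose setup succeeds.)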
+ hosts = shuffleHosts(hosts) + + conncfg := *e.control.session.connCfg + conncfg.disableCoalesce = true + var conn *Conn - if conn, err = e.control.shuffleDial(hosts); err != nil { - err = fmt.Errorf("connect: %w", err) - return + + for _, host := range hosts { + conn, err = e.control.session.dial(e.control.session.ctx, host, &conncfg, e.control) + if err != nil { + e.control.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) + continue + } + err = e.control.setupConn(conn) + if err == nil { + break + } + e.control.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) + conn.Close() + conn = nil } - if err = e.control.setupConn(conn); err != nil { + + if conn == nil { err = fmt.Errorf("setup: %w", err) return } diff --git a/vendor/github.com/gocql/gocql/frame.go b/vendor/github.com/gocql/gocql/frame.go index 9316a229f13..caf00eb34e4 100644 --- a/vendor/github.com/gocql/gocql/frame.go +++ b/vendor/github.com/gocql/gocql/frame.go @@ -346,9 +346,6 @@ type FrameHeaderObserver interface { // a framer is responsible for reading, writing and parsing frames on a single stream type framer struct { - r io.Reader - w io.Writer - proto byte // flags are for outgoing flags, enabling compression and tracing etc flags byte @@ -360,22 +357,23 @@ type framer struct { // if tracing flag is set this is not nil traceID []byte - // holds a ref to the whole byte slice for rbuf so that it can be reset to + // holds a ref to the whole byte slice for buf so that it can be reset to // 0 after a read. readBuffer []byte - rbuf []byte - wbuf []byte + buf []byte customPayload map[string][]byte - flagLWT int + flagLWT int + rateLimitingErrorCode int } -func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *framer { +func newFramer(compressor Compressor, version byte) *framer { + buf := make([]byte, defaultBufSize) f := &framer{ - wbuf: make([]byte, defaultBufSize), - readBuffer: make([]byte, defaultBufSize), + buf: buf[:0], + readBuffer: buf, } var flags byte if compressor != nil { @@ -397,22 +395,15 @@ func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *f f.flags = flags f.headSize = headSize - f.r = r - f.rbuf = f.readBuffer[:0] - - f.w = w - f.wbuf = f.wbuf[:0] - f.header = nil f.traceID = nil return f } -func newFramerWithExts(r io.Reader, w io.Writer, compressor Compressor, version byte, - cqlProtoExts []cqlProtocolExtension) *framer { +func newFramerWithExts(compressor Compressor, version byte, cqlProtoExts []cqlProtocolExtension) *framer { - f := newFramer(r, w, compressor, version) + f := newFramer(compressor, version) if lwtExt := findCQLProtoExtByName(cqlProtoExts, lwtAddMetadataMarkKey); lwtExt != nil { castedExt, ok := lwtExt.(*lwtAddMetadataMarkExt) @@ -425,6 +416,17 @@ func newFramerWithExts(r io.Reader, w io.Writer, compressor Compressor, version f.flagLWT = castedExt.lwtOptMetaBitMask } + if rateLimitErrorExt := findCQLProtoExtByName(cqlProtoExts, rateLimitError); rateLimitErrorExt != nil { + castedExt, ok := rateLimitErrorExt.(*rateLimitExt) + if !ok { + Logger.Println( + fmt.Errorf("Failed to cast CQL protocol extension identified by name %s to type %T", + rateLimitError, rateLimitExt{})) + return f + } + f.rateLimitingErrorCode = castedExt.rateLimitErrorCode + } + return f } @@ -491,12 +493,12 @@ func (f *framer) payload() { } // reads a frame form the wire into the framers buffer -func (f *framer) readFrame(head *frameHeader) error 
{ +func (f *framer) readFrame(r io.Reader, head *frameHeader) error { if head.length < 0 { return fmt.Errorf("frame body length can not be less than 0: %d", head.length) } else if head.length > maxFrameSize { // need to free up the connection to be used again - _, err := io.CopyN(ioutil.Discard, f.r, int64(head.length)) + _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) if err != nil { return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err) } @@ -504,14 +506,14 @@ func (f *framer) readFrame(head *frameHeader) error { } if cap(f.readBuffer) >= head.length { - f.rbuf = f.readBuffer[:head.length] + f.buf = f.readBuffer[:head.length] } else { f.readBuffer = make([]byte, head.length) - f.rbuf = f.readBuffer + f.buf = f.readBuffer } // assume the underlying reader takes care of timeouts and retries - n, err := io.ReadFull(f.r, f.rbuf) + n, err := io.ReadFull(r, f.buf) if err != nil { return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) } @@ -521,7 +523,7 @@ func (f *framer) readFrame(head *frameHeader) error { return NewErrProtocol("no compressor available with compressed frame body") } - f.rbuf, err = f.compres.Decode(f.rbuf) + f.buf, err = f.compres.Decode(f.buf) if err != nil { return err } @@ -699,7 +701,16 @@ func (f *framer) parseErrorFrame() frame { // TODO(zariel): we should have some distinct types for these errors return errD default: - panic(fmt.Errorf("unknown error code: 0x%x", errD.code)) + if f.rateLimitingErrorCode != 0 && code == f.rateLimitingErrorCode { + res := &RequestErrRateLimitReached{ + errorFrame: errD, + } + res.OpType = OpType(f.readByte()) + res.RejectedByCoordinator = f.readByte() != 0 + return res + } else { + panic(fmt.Errorf("unknown error code: 0x%x", errD.code)) + } } } @@ -714,25 +725,25 @@ func (f *framer) readErrorMap() (errMap ErrorMap) { } func (f *framer) writeHeader(flags byte, op frameOp, stream int) { - f.wbuf = f.wbuf[:0] - f.wbuf = append(f.wbuf, + f.buf = f.buf[:0] + f.buf = append(f.buf, f.proto, flags, ) if f.proto > protoVersion2 { - f.wbuf = append(f.wbuf, + f.buf = append(f.buf, byte(stream>>8), byte(stream), ) } else { - f.wbuf = append(f.wbuf, + f.buf = append(f.buf, byte(stream), ) } // pad out length - f.wbuf = append(f.wbuf, + f.buf = append(f.buf, byte(op), 0, 0, @@ -747,43 +758,43 @@ func (f *framer) setLength(length int) { p = 5 } - f.wbuf[p+0] = byte(length >> 24) - f.wbuf[p+1] = byte(length >> 16) - f.wbuf[p+2] = byte(length >> 8) - f.wbuf[p+3] = byte(length) + f.buf[p+0] = byte(length >> 24) + f.buf[p+1] = byte(length >> 16) + f.buf[p+2] = byte(length >> 8) + f.buf[p+3] = byte(length) } -func (f *framer) finishWrite() error { - if len(f.wbuf) > maxFrameSize { +func (f *framer) finish() error { + if len(f.buf) > maxFrameSize { // huge app frame, lets remove it so it doesn't bloat the heap - f.wbuf = make([]byte, defaultBufSize) + f.buf = make([]byte, defaultBufSize) return ErrFrameTooBig } - if f.wbuf[1]&flagCompress == flagCompress { + if f.buf[1]&flagCompress == flagCompress { if f.compres == nil { panic("compress flag set with no compressor") } // TODO: only compress frames which are big enough - compressed, err := f.compres.Encode(f.wbuf[f.headSize:]) + compressed, err := f.compres.Encode(f.buf[f.headSize:]) if err != nil { return err } - f.wbuf = append(f.wbuf[:f.headSize], compressed...) + f.buf = append(f.buf[:f.headSize], compressed...) 
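+		// The compressed payload replaces the plaintext body in place; the
+		// header bytes in buf[:headSize] are preserved, and setLength below
+		// patches the length field to the compressed size.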
} - length := len(f.wbuf) - f.headSize + length := len(f.buf) - f.headSize f.setLength(length) - _, err := f.w.Write(f.wbuf) - if err != nil { - return err - } - return nil } +func (f *framer) writeTo(w io.Writer) error { + _, err := w.Write(f.buf) + return err +} + func (f *framer) readTrace() { f.traceID = f.readUUID().Bytes() } @@ -822,11 +833,11 @@ func (w writeStartupFrame) String() string { return fmt.Sprintf("[startup opts=%+v]", w.opts) } -func (w *writeStartupFrame) writeFrame(f *framer, streamID int) error { +func (w *writeStartupFrame) buildFrame(f *framer, streamID int) error { f.writeHeader(f.flags&^flagCompress, opStartup, streamID) f.writeStringMap(w.opts) - return f.finishWrite() + return f.finish() } type writePrepareFrame struct { @@ -835,7 +846,7 @@ type writePrepareFrame struct { customPayload map[string][]byte } -func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error { +func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { if len(w.customPayload) > 0 { f.payload() } @@ -858,7 +869,7 @@ func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error { f.writeString(w.keyspace) } - return f.finishWrite() + return f.finish() } func (f *framer) readTypeInfo() TypeInfo { @@ -933,6 +944,10 @@ type preparedMetadata struct { // proto v4+ pkeyColumns []int + + keyspace string + + table string } func (r preparedMetadata) String() string { @@ -970,11 +985,10 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { return meta } - var keyspace, table string globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec if globalSpec { - keyspace = f.readString() - table = f.readString() + meta.keyspace = f.readString() + meta.table = f.readString() } var cols []ColumnInfo @@ -982,14 +996,14 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { // preallocate columninfo to avoid excess copying cols = make([]ColumnInfo, meta.colCount) for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta.resultMetadata, globalSpec, keyspace, table) + f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) } } else { // use append, huge number of columns usually indicates a corrupt frame or // just a huge row. 
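	// (Both branches now pass meta.keyspace and meta.table, which
	// parsePreparedMetadata stores on the struct when the global table
	// spec flag is set.)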
for i := 0; i < meta.colCount; i++ { var col ColumnInfo - f.readCol(&col, &meta.resultMetadata, globalSpec, keyspace, table) + f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) cols = append(cols, col) } } @@ -1422,14 +1436,14 @@ func (a *writeAuthResponseFrame) String() string { return fmt.Sprintf("[auth_response data=%q]", a.data) } -func (a *writeAuthResponseFrame) writeFrame(framer *framer, streamID int) error { +func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) error { return framer.writeAuthResponseFrame(streamID, a.data) } func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error { f.writeHeader(f.flags, opAuthResponse, streamID) f.writeBytes(data) - return f.finishWrite() + return f.finish() } type queryValues struct { @@ -1567,7 +1581,7 @@ func (w *writeQueryFrame) String() string { return fmt.Sprintf("[query statement=%q params=%v]", w.statement, w.params) } -func (w *writeQueryFrame) writeFrame(framer *framer, streamID int) error { +func (w *writeQueryFrame) buildFrame(framer *framer, streamID int) error { return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload) } @@ -1580,16 +1594,16 @@ func (f *framer) writeQueryFrame(streamID int, statement string, params *queryPa f.writeLongString(statement) f.writeQueryParams(params) - return f.finishWrite() + return f.finish() } -type frameWriter interface { - writeFrame(framer *framer, streamID int) error +type frameBuilder interface { + buildFrame(framer *framer, streamID int) error } type frameWriterFunc func(framer *framer, streamID int) error -func (f frameWriterFunc) writeFrame(framer *framer, streamID int) error { +func (f frameWriterFunc) buildFrame(framer *framer, streamID int) error { return f(framer, streamID) } @@ -1605,7 +1619,7 @@ func (e *writeExecuteFrame) String() string { return fmt.Sprintf("[execute id=% X params=%v]", e.preparedID, &e.params) } -func (e *writeExecuteFrame) writeFrame(fr *framer, streamID int) error { +func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) } @@ -1631,7 +1645,7 @@ func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *quer f.writeConsistency(params.consistency) } - return f.finishWrite() + return f.finish() } // TODO: can we replace BatchStatemt with batchStatement? 
As they prety much @@ -1656,7 +1670,7 @@ type writeBatchFrame struct { customPayload map[string][]byte } -func (w *writeBatchFrame) writeFrame(framer *framer, streamID int) error { +func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { return framer.writeBatchFrame(streamID, w, w.customPayload) } @@ -1734,25 +1748,25 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload } } - return f.finishWrite() + return f.finish() } type writeOptionsFrame struct{} -func (w *writeOptionsFrame) writeFrame(framer *framer, streamID int) error { +func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) error { return framer.writeOptionsFrame(streamID, w) } func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { f.writeHeader(f.flags&^flagCompress, opOptions, stream) - return f.finishWrite() + return f.finish() } type writeRegisterFrame struct { events []string } -func (w *writeRegisterFrame) writeFrame(framer *framer, streamID int) error { +func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) error { return framer.writeRegisterFrame(streamID, w) } @@ -1760,70 +1774,70 @@ func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { f.writeHeader(f.flags, opRegister, streamID) f.writeStringList(w.events) - return f.finishWrite() + return f.finish() } func (f *framer) readByte() byte { - if len(f.rbuf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.rbuf))) + if len(f.buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.buf))) } - b := f.rbuf[0] - f.rbuf = f.rbuf[1:] + b := f.buf[0] + f.buf = f.buf[1:] return b } func (f *framer) readInt() (n int) { - if len(f.rbuf) < 4 { - panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.rbuf))) + if len(f.buf) < 4 { + panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.buf))) } - n = int(int32(f.rbuf[0])<<24 | int32(f.rbuf[1])<<16 | int32(f.rbuf[2])<<8 | int32(f.rbuf[3])) - f.rbuf = f.rbuf[4:] + n = int(int32(f.buf[0])<<24 | int32(f.buf[1])<<16 | int32(f.buf[2])<<8 | int32(f.buf[3])) + f.buf = f.buf[4:] return } func (f *framer) readShort() (n uint16) { - if len(f.rbuf) < 2 { - panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.rbuf))) + if len(f.buf) < 2 { + panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.buf))) } - n = uint16(f.rbuf[0])<<8 | uint16(f.rbuf[1]) - f.rbuf = f.rbuf[2:] + n = uint16(f.buf[0])<<8 | uint16(f.buf[1]) + f.buf = f.buf[2:] return } func (f *framer) readString() (s string) { size := f.readShort() - if len(f.rbuf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.rbuf))) + if len(f.buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.buf))) } - s = string(f.rbuf[:size]) - f.rbuf = f.rbuf[size:] + s = string(f.buf[:size]) + f.buf = f.buf[size:] return } func (f *framer) readLongString() (s string) { size := f.readInt() - if len(f.rbuf) < size { - panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.rbuf))) + if len(f.buf) < size { + panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.buf))) } - s = string(f.rbuf[:size]) - f.rbuf = f.rbuf[size:] + s = string(f.buf[:size]) + f.buf = f.buf[size:] return } func 
(f *framer) readUUID() *UUID { - if len(f.rbuf) < 16 { - panic(fmt.Errorf("not enough bytes in buffer to read uuid require %d got: %d", 16, len(f.rbuf))) + if len(f.buf) < 16 { + panic(fmt.Errorf("not enough bytes in buffer to read uuid require %d got: %d", 16, len(f.buf))) } // TODO: how to handle this error, if it is a uuid, then sureley, problems? - u, _ := UUIDFromBytes(f.rbuf[:16]) - f.rbuf = f.rbuf[16:] + u, _ := UUIDFromBytes(f.buf[:16]) + f.buf = f.buf[16:] return &u } @@ -1844,12 +1858,12 @@ func (f *framer) readBytesInternal() ([]byte, error) { return nil, nil } - if len(f.rbuf) < size { - return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.rbuf)) + if len(f.buf) < size { + return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.buf)) } - l := f.rbuf[:size] - f.rbuf = f.rbuf[size:] + l := f.buf[:size] + f.buf = f.buf[size:] return l, nil } @@ -1865,35 +1879,35 @@ func (f *framer) readBytes() []byte { func (f *framer) readShortBytes() []byte { size := f.readShort() - if len(f.rbuf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.rbuf))) + if len(f.buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf))) } - l := f.rbuf[:size] - f.rbuf = f.rbuf[size:] + l := f.buf[:size] + f.buf = f.buf[size:] return l } func (f *framer) readInetAdressOnly() net.IP { - if len(f.rbuf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.rbuf))) + if len(f.buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf))) } - size := f.rbuf[0] - f.rbuf = f.rbuf[1:] + size := f.buf[0] + f.buf = f.buf[1:] if !(size == 4 || size == 16) { panic(fmt.Errorf("invalid IP size: %d", size)) } - if len(f.rbuf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.rbuf))) + if len(f.buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf))) } ip := make([]byte, size) - copy(ip, f.rbuf[:size]) - f.rbuf = f.rbuf[size:] + copy(ip, f.buf[:size]) + f.buf = f.buf[size:] return net.IP(ip) } @@ -1932,7 +1946,7 @@ func (f *framer) readStringMultiMap() map[string][]string { } func (f *framer) writeByte(b byte) { - f.wbuf = append(f.wbuf, b) + f.buf = append(f.buf, b) } func appendBytes(p []byte, d []byte) []byte { @@ -1989,29 +2003,29 @@ func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { // these are protocol level binary types func (f *framer) writeInt(n int32) { - f.wbuf = appendInt(f.wbuf, n) + f.buf = appendInt(f.buf, n) } func (f *framer) writeUint(n uint32) { - f.wbuf = appendUint(f.wbuf, n) + f.buf = appendUint(f.buf, n) } func (f *framer) writeShort(n uint16) { - f.wbuf = appendShort(f.wbuf, n) + f.buf = appendShort(f.buf, n) } func (f *framer) writeLong(n int64) { - f.wbuf = appendLong(f.wbuf, n) + f.buf = appendLong(f.buf, n) } func (f *framer) writeString(s string) { f.writeShort(uint16(len(s))) - f.wbuf = append(f.wbuf, s...) + f.buf = append(f.buf, s...) } func (f *framer) writeLongString(s string) { f.writeInt(int32(len(s))) - f.wbuf = append(f.wbuf, s...) + f.buf = append(f.buf, s...) } func (f *framer) writeStringList(l []string) { @@ -2037,13 +2051,13 @@ func (f *framer) writeBytes(p []byte) { f.writeInt(-1) } else { f.writeInt(int32(len(p))) - f.wbuf = append(f.wbuf, p...) 
+ f.buf = append(f.buf, p...) } } func (f *framer) writeShortBytes(p []byte) { f.writeShort(uint16(len(p))) - f.wbuf = append(f.wbuf, p...) + f.buf = append(f.buf, p...) } func (f *framer) writeConsistency(cons Consistency) { diff --git a/vendor/github.com/gocql/gocql/fuzz.go b/vendor/github.com/gocql/gocql/fuzz.go index 3606f9381d7..0d4cff0e57f 100644 --- a/vendor/github.com/gocql/gocql/fuzz.go +++ b/vendor/github.com/gocql/gocql/fuzz.go @@ -1,3 +1,4 @@ +//go:build gofuzz // +build gofuzz package gocql diff --git a/vendor/github.com/gocql/gocql/helpers.go b/vendor/github.com/gocql/gocql/helpers.go index 142577ce3a3..00f339779f3 100644 --- a/vendor/github.com/gocql/gocql/helpers.go +++ b/vendor/github.com/gocql/gocql/helpers.go @@ -20,52 +20,64 @@ type RowData struct { Values []interface{} } -func goType(t TypeInfo) reflect.Type { +func goType(t TypeInfo) (reflect.Type, error) { switch t.Type() { case TypeVarchar, TypeAscii, TypeInet, TypeText: - return reflect.TypeOf(*new(string)) + return reflect.TypeOf(*new(string)), nil case TypeBigInt, TypeCounter: - return reflect.TypeOf(*new(int64)) + return reflect.TypeOf(*new(int64)), nil case TypeTime: - return reflect.TypeOf(*new(time.Duration)) + return reflect.TypeOf(*new(time.Duration)), nil case TypeTimestamp: - return reflect.TypeOf(*new(time.Time)) + return reflect.TypeOf(*new(time.Time)), nil case TypeBlob: - return reflect.TypeOf(*new([]byte)) + return reflect.TypeOf(*new([]byte)), nil case TypeBoolean: - return reflect.TypeOf(*new(bool)) + return reflect.TypeOf(*new(bool)), nil case TypeFloat: - return reflect.TypeOf(*new(float32)) + return reflect.TypeOf(*new(float32)), nil case TypeDouble: - return reflect.TypeOf(*new(float64)) + return reflect.TypeOf(*new(float64)), nil case TypeInt: - return reflect.TypeOf(*new(int)) + return reflect.TypeOf(*new(int)), nil case TypeSmallInt: - return reflect.TypeOf(*new(int16)) + return reflect.TypeOf(*new(int16)), nil case TypeTinyInt: - return reflect.TypeOf(*new(int8)) + return reflect.TypeOf(*new(int8)), nil case TypeDecimal: - return reflect.TypeOf(*new(*inf.Dec)) + return reflect.TypeOf(*new(*inf.Dec)), nil case TypeUUID, TypeTimeUUID: - return reflect.TypeOf(*new(UUID)) + return reflect.TypeOf(*new(UUID)), nil case TypeList, TypeSet: - return reflect.SliceOf(goType(t.(CollectionType).Elem)) + elemType, err := goType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.SliceOf(elemType), nil case TypeMap: - return reflect.MapOf(goType(t.(CollectionType).Key), goType(t.(CollectionType).Elem)) + keyType, err := goType(t.(CollectionType).Key) + if err != nil { + return nil, err + } + valueType, err := goType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.MapOf(keyType, valueType), nil case TypeVarint: - return reflect.TypeOf(*new(*big.Int)) + return reflect.TypeOf(*new(*big.Int)), nil case TypeTuple: // what can we do here? 
all there is to do is to make a list of interface{} tuple := t.(TupleTypeInfo) - return reflect.TypeOf(make([]interface{}, len(tuple.Elems))) + return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil case TypeUDT: - return reflect.TypeOf(make(map[string]interface{})) + return reflect.TypeOf(make(map[string]interface{})), nil case TypeDate: - return reflect.TypeOf(*new(time.Time)) + return reflect.TypeOf(*new(time.Time)), nil case TypeDuration: - return reflect.TypeOf(*new(Duration)) + return reflect.TypeOf(*new(Duration)), nil default: - return nil + return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) } } @@ -300,13 +312,20 @@ func (iter *Iter) RowData() (RowData, error) { for _, column := range iter.Columns() { if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { - val := column.TypeInfo.New() + val, err := column.TypeInfo.NewWithError() + if err != nil { + return RowData{}, err + } columns = append(columns, column.Name) values = append(values, val) } else { for i, elem := range c.Elems { columns = append(columns, TupleColumnName(column.Name, i)) - values = append(values, elem.New()) + val, err := elem.NewWithError() + if err != nil { + return RowData{}, err + } + values = append(values, val) } } } diff --git a/vendor/github.com/gocql/gocql/host_source.go b/vendor/github.com/gocql/gocql/host_source.go index 5ce4634a0e3..ae0de33b5f1 100644 --- a/vendor/github.com/gocql/gocql/host_source.go +++ b/vendor/github.com/gocql/gocql/host_source.go @@ -11,6 +11,9 @@ import ( "time" ) +var ErrCannotFindHost = errors.New("cannot find host") +var ErrHostAlreadyExists = errors.New("host already exists") + type nodeState int32 func (n nodeState) String() string { @@ -426,6 +429,13 @@ func (h *HostInfo) Hostname() string { return h.hostname } +func (h *HostInfo) ConnectAddressAndPort() string { + h.mu.Lock() + defer h.mu.Unlock() + addr, _ := h.connectAddressLocked() + return net.JoinHostPort(addr.String(), strconv.Itoa(h.port)) +} + func (h *HostInfo) String() string { h.mu.RLock() defer h.mu.RUnlock() @@ -551,12 +561,24 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* return nil, fmt.Errorf(assertErrorMsg, "rpc_address") } host.rpcAddress = net.ParseIP(ip) + case "native_address": + ip, ok := value.(string) + if !ok { + return nil, fmt.Errorf(assertErrorMsg, "native_address") + } + host.rpcAddress = net.ParseIP(ip) case "listen_address": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "listen_address") } host.listenAddress = net.ParseIP(ip) + case "native_port": + native_port, ok := value.(int) + if !ok { + return nil, fmt.Errorf(assertErrorMsg, "native_port") + } + host.port = native_port case "workload": host.workload, ok = value.(string) if !ok { @@ -596,12 +618,54 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* return host, nil } +func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int) (*HostInfo, error) { + rows, err := iter.SliceMap() + if err != nil { + // TODO(zariel): make typed error + return nil, err + } + + if len(rows) == 0 { + return nil, errors.New("query returned 0 rows") + } + + host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}) + if err != nil { + return nil, err + } + return host, nil +} + +// Ask the control node for the local host info +func (r *ringDescriber) getLocalHostInfo() (*HostInfo, error) { + if r.session.control == nil { + return nil, errNoControl + } + + iter := 
r.session.control.withConnHost(func(ch *connHost) *Iter { + return ch.conn.querySystemLocal(context.TODO()) + }) + + if iter == nil { + return nil, errNoControl + } + + host, err := r.session.hostInfoFromIter(iter, nil, r.session.cfg.Port) + if err != nil { + return nil, fmt.Errorf("could not retrieve local host info: %w", err) + } + return host, nil +} + // Ask the control node for host info on all it's known peers -func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) { - var hosts []*HostInfo +func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, error) { + if r.session.control == nil { + return nil, errNoControl + } + + var peers []*HostInfo iter := r.session.control.withConnHost(func(ch *connHost) *Iter { - hosts = append(hosts, ch.host) - return ch.conn.query(context.TODO(), "SELECT * FROM system.peers") + return ch.conn.querySystemPeers(context.TODO(), localHost.version) }) if iter == nil { @@ -626,10 +690,10 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) { continue } - hosts = append(hosts, host) + peers = append(peers, host) } - return hosts, nil + return peers, nil } // Return true if the host is a valid peer @@ -641,16 +705,22 @@ func isValidPeer(host *HostInfo) bool { len(host.tokens) == 0) } -// Return a list of hosts the cluster knows about +// GetHosts returns a list of hosts found via queries to system.local and system.peers func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) { r.mu.Lock() defer r.mu.Unlock() - hosts, err := r.getClusterPeerInfo() + localHost, err := r.getLocalHostInfo() + if err != nil { + return r.prevHosts, r.prevPartitioner, err + } + + peerHosts, err := r.getClusterPeerInfo(localHost) if err != nil { return r.prevHosts, r.prevPartitioner, err } + hosts := append([]*HostInfo{localHost}, peerHosts...) var partitioner string if len(hosts) > 0 { partitioner = hosts[0].Partitioner() @@ -669,7 +739,11 @@ func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) { return nil } - return ch.conn.query(context.TODO(), fmt.Sprintf("SELECT * FROM %s", table)) + if table == "system.peers" { + return ch.conn.querySystemPeers(context.TODO(), ch.host.version) + } else { + return ch.conn.query(context.TODO(), fmt.Sprintf("SELECT * FROM %s", table)) + } }) if iter != nil { @@ -701,11 +775,22 @@ func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) { return host, nil } -func (r *ringDescriber) refreshRing() error { - // if we have 0 hosts this will return the previous list of hosts to - // attempt to reconnect to the cluster otherwise we would never find - // downed hosts again, could possibly have an optimisation to only - // try to add new hosts if GetHosts didnt error and the hosts didnt change. +// debounceRingRefresh submits a ring refresh request to the ring refresh debouncer. +func (s *Session) debounceRingRefresh() { + s.ringRefresher.debounce() +} + +// refreshRing executes a ring refresh immediately and cancels pending debounce ring refresh requests. 
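+// The refresh result (an error or nil) is delivered on the returned channel;
+// the channel is closed without a value when the debouncer was stopped before
+// the refresh could run.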
+func (s *Session) refreshRing() error { + err, ok := <-s.ringRefresher.refreshNow() + if !ok { + return errors.New("could not refresh ring because stop was requested") + } + + return err +} + +func refreshRing(r *ringDescriber) error { hosts, partitioner, err := r.GetHosts() if err != nil { return err @@ -713,7 +798,6 @@ func (r *ringDescriber) refreshRing() error { prevHosts := r.session.ring.currentHosts() - // TODO: move this to session for _, h := range hosts { if r.session.cfg.filterHost(h) { continue @@ -722,14 +806,29 @@ func (r *ringDescriber) refreshRing() error { if host, ok := r.session.ring.addHostIfMissing(h); !ok { r.session.startPoolFill(h) } else { - host.update(h) + // host (by hostID) already exists; determine if IP has changed + newHostID := h.HostID() + existing, ok := prevHosts[newHostID] + if !ok { + return fmt.Errorf("get existing host=%s from prevHosts: %w", h, ErrCannotFindHost) + } + if h.connectAddress.Equal(existing.connectAddress) && h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) { + // no host IP change + host.update(h) + } else { + // host IP has changed + // remove old HostInfo (w/old IP) + r.session.removeHost(existing) + if _, alreadyExists := r.session.ring.addHostIfMissing(h); alreadyExists { + return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists) + } + // add new HostInfo (same hostID, new IP) + r.session.startPoolFill(h) + } } delete(prevHosts, h.HostID()) } - // TODO(zariel): it may be worth having a mutex covering the overall ring state - // in a session so that everything sees a consistent state. Becuase as is today - // events can come in and due to ordering an UP host could be removed from the cluster for _, host := range prevHosts { r.session.removeHost(host) } @@ -738,3 +837,161 @@ func (r *ringDescriber) refreshRing() error { r.session.policy.SetPartitioner(partitioner) return nil } + +const ( + ringRefreshDebounceTime = 1 * time.Second +) + +// debounces requests to call a refresh function (currently used for ring refresh). It also supports triggering a refresh immediately. 
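+// A debounce() call (re)arms the timer; when it fires, the flusher goroutine
+// runs refreshFn once. refreshNow() bypasses the timer and broadcasts the
+// resulting error to every listener registered since the previous run.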
+type refreshDebouncer struct { + mu sync.Mutex + stopped bool + broadcaster *errorBroadcaster + interval time.Duration + timer *time.Timer + refreshNowCh chan struct{} + quit chan struct{} + refreshFn func() error +} + +func newRefreshDebouncer(interval time.Duration, refreshFn func() error) *refreshDebouncer { + d := &refreshDebouncer{ + stopped: false, + broadcaster: nil, + refreshNowCh: make(chan struct{}, 1), + quit: make(chan struct{}), + interval: interval, + timer: time.NewTimer(interval), + refreshFn: refreshFn, + } + d.timer.Stop() + go d.flusher() + return d +} + +// debounces a request to call the refresh function +func (d *refreshDebouncer) debounce() { + d.mu.Lock() + defer d.mu.Unlock() + if d.stopped { + return + } + d.timer.Reset(d.interval) +} + +// requests an immediate refresh which will cancel pending refresh requests +func (d *refreshDebouncer) refreshNow() <-chan error { + d.mu.Lock() + defer d.mu.Unlock() + if d.broadcaster == nil { + d.broadcaster = newErrorBroadcaster() + select { + case d.refreshNowCh <- struct{}{}: + default: + // already a refresh pending + } + } + return d.broadcaster.newListener() +} + +func (d *refreshDebouncer) flusher() { + for { + select { + case <-d.refreshNowCh: + case <-d.timer.C: + case <-d.quit: + } + d.mu.Lock() + if d.stopped { + if d.broadcaster != nil { + d.broadcaster.stop() + d.broadcaster = nil + } + d.timer.Stop() + d.mu.Unlock() + return + } + + // make sure both request channels are cleared before we refresh + select { + case <-d.refreshNowCh: + default: + } + + d.timer.Stop() + select { + case <-d.timer.C: + default: + } + + curBroadcaster := d.broadcaster + d.broadcaster = nil + d.mu.Unlock() + + err := d.refreshFn() + if curBroadcaster != nil { + curBroadcaster.broadcast(err) + } + } +} + +func (d *refreshDebouncer) stop() { + d.mu.Lock() + if d.stopped { + d.mu.Unlock() + return + } + d.stopped = true + d.mu.Unlock() + d.quit <- struct{}{} // sync with flusher + close(d.quit) +} + +// broadcasts an error to multiple channels (listeners) +type errorBroadcaster struct { + listeners []chan<- error + mu sync.Mutex +} + +func newErrorBroadcaster() *errorBroadcaster { + return &errorBroadcaster{ + listeners: nil, + mu: sync.Mutex{}, + } +} + +func (b *errorBroadcaster) newListener() <-chan error { + ch := make(chan error, 1) + b.mu.Lock() + defer b.mu.Unlock() + b.listeners = append(b.listeners, ch) + return ch +} + +func (b *errorBroadcaster) broadcast(err error) { + b.mu.Lock() + defer b.mu.Unlock() + curListeners := b.listeners + if len(curListeners) > 0 { + b.listeners = nil + } else { + return + } + + for _, listener := range curListeners { + listener <- err + close(listener) + } +} + +func (b *errorBroadcaster) stop() { + b.mu.Lock() + defer b.mu.Unlock() + if len(b.listeners) == 0 { + return + } + for _, listener := range b.listeners { + close(listener) + } + b.listeners = nil +} diff --git a/vendor/github.com/gocql/gocql/host_source_gen.go b/vendor/github.com/gocql/gocql/host_source_gen.go index c82193cbd49..8c096ffd6de 100644 --- a/vendor/github.com/gocql/gocql/host_source_gen.go +++ b/vendor/github.com/gocql/gocql/host_source_gen.go @@ -1,3 +1,4 @@ +//go:build genhostinfo // +build genhostinfo package main diff --git a/vendor/github.com/gocql/gocql/integration.sh b/vendor/github.com/gocql/gocql/integration.sh index 749caa730d5..5c29615e957 100644 --- a/vendor/github.com/gocql/gocql/integration.sh +++ b/vendor/github.com/gocql/gocql/integration.sh @@ -8,26 +8,16 @@ readonly SCYLLA_IMAGE=${SCYLLA_IMAGE} set -eu -o 
pipefail function scylla_up() { - local -r exec="docker-compose exec -T" + local -r exec="docker compose exec -T" echo "==> Running Scylla ${SCYLLA_IMAGE}" docker pull ${SCYLLA_IMAGE} - docker-compose up -d - - echo "==> Waiting for CQL port" - for s in $(docker-compose ps --services); do - until v=$(${exec} ${s} cqlsh -e "DESCRIBE SCHEMA"); do - echo ${v} - docker-compose logs --tail 10 ${s} - sleep 5 - done - done - echo "==> Waiting for CQL port done" + docker compose up -d --wait } function scylla_down() { echo "==> Stopping Scylla" - docker-compose down + docker compose down } function scylla_restart() { diff --git a/vendor/github.com/gocql/gocql/internal/streams/streams.go b/vendor/github.com/gocql/gocql/internal/streams/streams.go index ea43412aac6..05bcd7d6a4e 100644 --- a/vendor/github.com/gocql/gocql/internal/streams/streams.go +++ b/vendor/github.com/gocql/gocql/internal/streams/streams.go @@ -24,6 +24,13 @@ func New(protocol int) *IDGenerator { if protocol > 2 { maxStreams = 32768 } + return NewLimited(maxStreams) +} + +func NewLimited(maxStreams int) *IDGenerator { + // Round up maxStreams to a nearest + // multiple of 64 + maxStreams = ((maxStreams + 63) / 64) * 64 buckets := maxStreams / 64 // reserve stream 0 diff --git a/vendor/github.com/gocql/gocql/marshal.go b/vendor/github.com/gocql/gocql/marshal.go index 7fa53fe83fe..898288a875b 100644 --- a/vendor/github.com/gocql/gocql/marshal.go +++ b/vendor/github.com/gocql/gocql/marshal.go @@ -51,45 +51,45 @@ type Unmarshaler interface { // // Supported conversions are as follows, other type combinations may be added in the future: // -// CQL type | Go type (value) | Note -// varchar, ascii, blob, text | string, []byte | -// boolean | bool | -// tinyint, smallint, int | integer types | -// tinyint, smallint, int | string | formatted as base 10 number -// bigint, counter | integer types | -// bigint, counter | big.Int | -// bigint, counter | string | formatted as base 10 number -// float | float32 | -// double | float64 | -// decimal | inf.Dec | -// time | int64 | nanoseconds since start of day -// time | time.Duration | duration since start of day -// timestamp | int64 | milliseconds since Unix epoch -// timestamp | time.Time | -// list, set | slice, array | -// list, set | map[X]struct{} | -// map | map[X]Y | -// uuid, timeuuid | gocql.UUID | -// uuid, timeuuid | [16]byte | raw UUID bytes -// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes -// uuid, timeuuid | string | hex representation, see ParseUUID -// varint | integer types | -// varint | big.Int | -// varint | string | value of number in decimal notation -// inet | net.IP | -// inet | string | IPv4 or IPv6 address string -// tuple | slice, array | -// tuple | struct | fields are marshaled in order of declaration -// user-defined type | gocql.UDTMarshaler | MarshalUDT is called -// user-defined type | map[string]interface{} | -// user-defined type | struct | struct fields' cql tags are used for column names -// date | int64 | milliseconds since Unix epoch to start of day (in UTC) -// date | time.Time | start of day (in UTC) -// date | string | parsed using "2006-01-02" format -// duration | int64 | duration in nanoseconds -// duration | time.Duration | -// duration | gocql.Duration | -// duration | string | parsed with time.ParseDuration +// CQL type | Go type (value) | Note +// varchar, ascii, blob, text | string, []byte | +// boolean | bool | +// tinyint, smallint, int | integer types | +// tinyint, smallint, int | string | formatted as base 10 number 
+// bigint, counter | integer types | +// bigint, counter | big.Int | +// bigint, counter | string | formatted as base 10 number +// float | float32 | +// double | float64 | +// decimal | inf.Dec | +// time | int64 | nanoseconds since start of day +// time | time.Duration | duration since start of day +// timestamp | int64 | milliseconds since Unix epoch +// timestamp | time.Time | +// list, set | slice, array | +// list, set | map[X]struct{} | +// map | map[X]Y | +// uuid, timeuuid | gocql.UUID | +// uuid, timeuuid | [16]byte | raw UUID bytes +// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes +// uuid, timeuuid | string | hex representation, see ParseUUID +// varint | integer types | +// varint | big.Int | +// varint | string | value of number in decimal notation +// inet | net.IP | +// inet | string | IPv4 or IPv6 address string +// tuple | slice, array | +// tuple | struct | fields are marshaled in order of declaration +// user-defined type | gocql.UDTMarshaler | MarshalUDT is called +// user-defined type | map[string]interface{} | +// user-defined type | struct | struct fields' cql tags are used for column names +// date | int64 | milliseconds since Unix epoch to start of day (in UTC) +// date | time.Time | start of day (in UTC) +// date | string | parsed using "2006-01-02" format +// duration | int64 | duration in nanoseconds +// duration | time.Duration | +// duration | gocql.Duration | +// duration | string | parsed with time.ParseDuration func Marshal(info TypeInfo, value interface{}) ([]byte, error) { if info.Version() < protoVersion1 { panic("protocol version not set") @@ -172,36 +172,36 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { // // Supported conversions are as follows, other type combinations may be added in the future: // -// CQL type | Go type (value) | Note -// varchar, ascii, blob, text | *string | -// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused -// bool | *bool | -// tinyint, smallint, int, bigint, counter | *integer types | -// tinyint, smallint, int, bigint, counter | *big.Int | -// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number -// float | *float32 | -// double | *float64 | -// decimal | *inf.Dec | -// time | *int64 | nanoseconds since start of day -// time | *time.Duration | -// timestamp | *int64 | milliseconds since Unix epoch -// timestamp | *time.Time | -// list, set | *slice, *array | -// map | *map[X]Y | -// uuid, timeuuid | *string | see UUID.String -// uuid, timeuuid | *[]byte | raw UUID bytes -// uuid, timeuuid | *gocql.UUID | -// timeuuid | *time.Time | timestamp of the UUID -// inet | *net.IP | -// inet | *string | IPv4 or IPv6 address string -// tuple | *slice, *array | -// tuple | *struct | struct fields are set in order of declaration -// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called -// user-defined types | *map[string]interface{} | -// user-defined types | *struct | cql tag is used to determine field name -// date | *time.Time | time of beginning of the day (in UTC) -// date | *string | formatted with 2006-01-02 format -// duration | *gocql.Duration | +// CQL type | Go type (value) | Note +// varchar, ascii, blob, text | *string | +// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused +// bool | *bool | +// tinyint, smallint, int, bigint, counter | *integer types | +// tinyint, smallint, int, bigint, counter | *big.Int | +// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number +// float | 
*float32 | +// double | *float64 | +// decimal | *inf.Dec | +// time | *int64 | nanoseconds since start of day +// time | *time.Duration | +// timestamp | *int64 | milliseconds since Unix epoch +// timestamp | *time.Time | +// list, set | *slice, *array | +// map | *map[X]Y | +// uuid, timeuuid | *string | see UUID.String +// uuid, timeuuid | *[]byte | raw UUID bytes +// uuid, timeuuid | *gocql.UUID | +// timeuuid | *time.Time | timestamp of the UUID +// inet | *net.IP | +// inet | *string | IPv4 or IPv6 address string +// tuple | *slice, *array | +// tuple | *struct | struct fields are set in order of declaration +// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called +// user-defined types | *map[string]interface{} | +// user-defined types | *struct | cql tag is used to determine field name +// date | *time.Time | time of beginning of the day (in UTC) +// date | *string | formatted with 2006-01-02 format +// duration | *gocql.Duration | func Unmarshal(info TypeInfo, data []byte, value interface{}) error { if v, ok := value.(Unmarshaler); ok { return v.UnmarshalCQL(info, data) @@ -1174,6 +1174,9 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { case Unmarshaler: return v.UnmarshalCQL(info, data) case *inf.Dec: + if len(data) < 4 { + return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) + } scale := decInt(data[0:4]) unscaled := decBigInt2C(data[4:], nil) *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) @@ -1331,6 +1334,8 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } +const millisecondsInADay int64 = 24 * 60 * 60 * 1000 + func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { var timestamp int64 switch v := value.(type) { @@ -1340,21 +1345,21 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { return nil, nil case int64: timestamp = v - x := timestamp/86400000 + int64(1<<31) + x := timestamp/millisecondsInADay + int64(1<<31) return encInt(int32(x)), nil case time.Time: if v.IsZero() { return []byte{}, nil } timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/86400000 + int64(1<<31) + x := timestamp/millisecondsInADay + int64(1<<31) return encInt(int32(x)), nil case *time.Time: if v.IsZero() { return []byte{}, nil } timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/86400000 + int64(1<<31) + x := timestamp/millisecondsInADay + int64(1<<31) return encInt(int32(x)), nil case string: if v == "" { @@ -1365,7 +1370,7 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) } timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) - x := timestamp/86400000 + int64(1<<31) + x := timestamp/millisecondsInADay + int64(1<<31) return encInt(int32(x)), nil } @@ -1386,8 +1391,8 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { } var origin uint32 = 1 << 31 var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * 86400000 - *v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC) + timestamp := (int64(current) - int64(origin)) * millisecondsInADay + *v = time.UnixMilli(timestamp).In(time.UTC) return nil case *string: if len(data) == 0 { @@ -1396,8 +1401,8 @@ func unmarshalDate(info TypeInfo, data 
[]byte, value interface{}) error { } var origin uint32 = 1 << 31 var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * 86400000 - *v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC).Format("2006-01-02") + timestamp := (int64(current) - int64(origin)) * millisecondsInADay + *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) @@ -1448,7 +1453,10 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { } return nil } - months, days, nanos := decVints(data) + months, days, nanos, err := decVints(data) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) + } *v = Duration{ Months: months, Days: days, @@ -1459,25 +1467,40 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decVints(data []byte) (int32, int32, int64) { - month, i := decVint(data) - days, j := decVint(data[i:]) - nanos, _ := decVint(data[i+j:]) - return int32(month), int32(days), nanos +func decVints(data []byte) (int32, int32, int64, error) { + month, i, err := decVint(data, 0) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) + } + days, i, err := decVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) + } + nanos, _, err := decVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) + } + return int32(month), int32(days), nanos, err } -func decVint(data []byte) (int64, int) { - firstByte := data[0] +func decVint(data []byte, start int) (int64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] if firstByte&0x80 == 0 { - return decIntZigZag(uint64(firstByte)), 1 + return decIntZigZag(uint64(firstByte)), start + 1, nil } numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 ret := uint64(firstByte & (0xff >> uint(numBytes))) - for i := 0; i < numBytes; i++ { + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("expected data to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { ret <<= 8 ret |= uint64(data[i+1] & 0xff) } - return decIntZigZag(ret), numBytes + 1 + return decIntZigZag(ret), start + numBytes + 1, nil } func decIntZigZag(n uint64) int64 { @@ -1567,7 +1590,12 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { if err != nil { return nil, err } - if err := writeCollectionSize(listInfo, len(item), buf); err != nil { + itemLen := len(item) + // Set the value to null for supported protocols + if item == nil && listInfo.proto > protoVersion2 { + itemLen = -1 + } + if err := writeCollectionSize(listInfo, itemLen, buf); err != nil { return nil, err } buf.Write(item) @@ -1592,7 +1620,7 @@ func readCollectionSize(info CollectionType, data []byte) (size, read int, err e if len(data) < 4 { return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") } - size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + size = int(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3])) read = 4 } else { if len(data) < 2 { @@ -1648,13 +1676,18 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return err } data = data[p:] - if len(data) < m {
- return unmarshalErrorf("unmarshal list: unexpected eof") + // In case m < 0, the value is null, and unmarshalData should be nil. + var unmarshalData []byte + if m >= 0 { + if len(data) < m { + return unmarshalErrorf("unmarshal list: unexpected eof") + } + unmarshalData = data[:m] + data = data[m:] } - if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil { + if err := Unmarshal(listInfo.Elem, unmarshalData, rv.Index(i).Addr().Interface()); err != nil { return err } - data = data[m:] } return nil } @@ -1697,7 +1730,12 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { if err != nil { return nil, err } - if err := writeCollectionSize(mapInfo, len(item), buf); err != nil { + itemLen := len(item) + // Set the key to null for supported protocols + if item == nil && mapInfo.proto > protoVersion2 { + itemLen = -1 + } + if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { return nil, err } buf.Write(item) @@ -1706,7 +1744,12 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { if err != nil { return nil, err } - if err := writeCollectionSize(mapInfo, len(item), buf); err != nil { + itemLen = len(item) + // Set the value to null for supported protocols + if item == nil && mapInfo.proto > protoVersion2 { + itemLen = -1 + } + if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { return nil, err } buf.Write(item) @@ -1733,11 +1776,14 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { rv.Set(reflect.Zero(t)) return nil } - rv.Set(reflect.MakeMap(t)) n, p, err := readCollectionSize(mapInfo, data) if err != nil { return err } + if n < 0 { + return unmarshalErrorf("negative map size %d", n) + } + rv.Set(reflect.MakeMapWithSize(t, n)) data = data[p:] for i := 0; i < n; i++ { m, p, err := readCollectionSize(mapInfo, data) @@ -1745,28 +1791,39 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { return err } data = data[p:] - if len(data) < m { - return unmarshalErrorf("unmarshal map: unexpected eof") - } key := reflect.New(t.Key()) - if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil { + // In case m < 0, the key is null, and unmarshalData should be nil. + var unmarshalData []byte + if m >= 0 { + if len(data) < m { + return unmarshalErrorf("unmarshal map: unexpected eof") + } + unmarshalData = data[:m] + data = data[m:] + } + if err := Unmarshal(mapInfo.Key, unmarshalData, key.Interface()); err != nil { return err } - data = data[m:] m, p, err = readCollectionSize(mapInfo, data) if err != nil { return err } data = data[p:] - if len(data) < m { - return unmarshalErrorf("unmarshal map: unexpected eof") - } val := reflect.New(t.Elem()) - if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil { + + // In case m < 0, the value is null, and unmarshalData should be nil. 
+ unmarshalData = nil + if m >= 0 { + if len(data) < m { + return unmarshalErrorf("unmarshal map: unexpected eof") + } + unmarshalData = data[:m] + data = data[m:] + } + if err := Unmarshal(mapInfo.Elem, unmarshalData, val.Interface()); err != nil { return err } - data = data[m:] rv.SetMapIndex(key.Elem(), val.Elem()) } @@ -2076,7 +2133,10 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { p, data = readBytes(data) } - v := elem.New() + v, err := elem.NewWithError() + if err != nil { + return err + } if err := Unmarshal(elem, p, v); err != nil { return err } @@ -2110,7 +2170,10 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { p, data = readBytes(data) } - v := elem.New() + v, err := elem.NewWithError() + if err != nil { + return err + } if err := Unmarshal(elem, p, v); err != nil { return err } @@ -2176,13 +2239,14 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { var buf []byte for _, e := range udt.Elements { val, ok := v[e.Name] - if !ok { - return nil, marshalErrorf("marshal missing map key %q", e.Name) - } + var data []byte - data, err := Marshal(e.Type, val) - if err != nil { - return nil, err + if ok { + var err error + data, err = Marshal(e.Type, val) + if err != nil { + return nil, err + } } buf = appendBytes(buf, data) @@ -2242,14 +2306,16 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { case UDTUnmarshaler: udt := info.(UDTTypeInfo) - for _, e := range udt.Elements { + for id, e := range udt.Elements { if len(data) == 0 { return nil } + if len(data) < 4 { + return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + } var p []byte p, data = readBytes(data) - if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { return err } @@ -2276,12 +2342,20 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { rv.Set(reflect.MakeMap(t)) m := *v - for _, e := range udt.Elements { + for id, e := range udt.Elements { if len(data) == 0 { return nil } + if len(data) < 4 { + return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + } + + valType, err := goType(e.Type) + if err != nil { + return unmarshalErrorf("can not unmarshal %s: %v", info, err) + } - val := reflect.New(goType(e.Type)) + val := reflect.New(valType) var p []byte p, data = readBytes(data) @@ -2296,7 +2370,11 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { return nil } - k := reflect.ValueOf(value).Elem() + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + k := rv.Elem() if k.Kind() != reflect.Struct || !k.IsValid() { return unmarshalErrorf("cannot unmarshal %s into %T", info, value) } @@ -2320,10 +2398,13 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } udt := info.(UDTTypeInfo) - for _, e := range udt.Elements { + for id, e := range udt.Elements { + if len(data) == 0 { + return nil + } if len(data) < 4 { // UDT def does not match the column value - return nil + return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) } var p []byte @@ -2359,8 +2440,18 @@ type TypeInfo interface { Custom() string // New creates a pointer to an empty version of whatever type - // is referenced by the TypeInfo receiver + // is referenced by the TypeInfo receiver. + // + // If there is no corresponding Go type for the CQL type, New panics. 
+ // + // Deprecated: Use NewWithError instead. New() interface{} + + // NewWithError creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver. + // + // If there is no corresponding Go type for the CQL type, NewWithError returns an error. + NewWithError() (interface{}, error) } type NativeType struct { @@ -2373,8 +2464,20 @@ func NewNativeType(proto byte, typ Type, custom string) NativeType { return NativeType{proto, typ, custom} } +func (t NativeType) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (t NativeType) New() interface{} { - return reflect.New(goType(t)).Interface() + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val } func (s NativeType) Type() Type { @@ -2404,8 +2507,20 @@ type CollectionType struct { Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } +func (t CollectionType) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (t CollectionType) New() interface{} { - return reflect.New(goType(t)).Interface() + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val } func (c CollectionType) String() string { @@ -2437,8 +2552,20 @@ func (t TupleTypeInfo) String() string { return buf.String() } +func (t TupleTypeInfo) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (t TupleTypeInfo) New() interface{} { - return reflect.New(goType(t)).Interface() + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val } type UDTField struct { @@ -2453,8 +2580,20 @@ type UDTTypeInfo struct { Elements []UDTField } +func (u UDTTypeInfo) NewWithError() (interface{}, error) { + typ, err := goType(u) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (u UDTTypeInfo) New() interface{} { - return reflect.New(goType(u)).Interface() + val, err := u.NewWithError() + if err != nil { + panic(err.Error()) + } + return val } func (u UDTTypeInfo) String() string { diff --git a/vendor/github.com/gocql/gocql/metadata_cassandra.go b/vendor/github.com/gocql/gocql/metadata_cassandra.go index 6f833e442fa..bcf8b651e5a 100644 --- a/vendor/github.com/gocql/gocql/metadata_cassandra.go +++ b/vendor/github.com/gocql/gocql/metadata_cassandra.go @@ -136,7 +136,7 @@ type ColumnOrder bool const ( ASC ColumnOrder = false - DESC = true + DESC ColumnOrder = true ) type ColumnIndexMetadata struct { diff --git a/vendor/github.com/gocql/gocql/policies.go b/vendor/github.com/gocql/gocql/policies.go index 5305519c28b..6373a2c7c91 100644 --- a/vendor/github.com/gocql/gocql/policies.go +++ b/vendor/github.com/gocql/gocql/policies.go @@ -133,12 +133,11 @@ type RetryPolicy interface { // // See below for examples of usage: // -// //Assign to the cluster -// cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} -// -// //Assign to a query -// query.RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 1}) +// //Assign to the cluster +// cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} // +// //Assign to a query +// query.RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 1}) type SimpleRetryPolicy struct { NumRetries int //Number of times to retry a query } @@ -260,8 +259,23 @@ type KeyspaceUpdateEvent struct { Change string } 
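The NewWithError constructors above let callers handle a CQL type that has no corresponding Go type as an ordinary error, instead of recovering from the panic that the deprecated New performs. A minimal usage sketch, assuming a gocql.TypeInfo obtained from column metadata; the helper name newValueFor is illustrative:

package gocqlutil

import (
	"fmt"

	"github.com/gocql/gocql"
)

// newValueFor allocates a destination value for a column, surfacing the
// condition that the deprecated TypeInfo.New turns into a panic as an error.
func newValueFor(info gocql.TypeInfo) (interface{}, error) {
	v, err := info.NewWithError()
	if err != nil {
		return nil, fmt.Errorf("no Go type for CQL type %v: %w", info.Type(), err)
	}
	return v, nil
}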
+type HostTierer interface { + // HostTier returns an integer specifying how far a host is from the client. + // Tier must start at 0. + // The value is used to prioritize closer hosts during host selection. + // For example this could be: + // 0 - local rack, 1 - local DC, 2 - remote DC + // or: + // 0 - local DC, 1 - remote DC + HostTier(host *HostInfo) uint + + // This function returns the maximum possible host tier + MaxHostTier() uint +} + // HostSelectionPolicy is an interface for selecting // the most appropriate host to execute a given query. +// HostSelectionPolicy instances cannot be shared between sessions. type HostSelectionPolicy interface { HostStateNotifier SetPartitioner @@ -396,6 +410,13 @@ type tokenAwareHostPolicy struct { } func (t *tokenAwareHostPolicy) Init(s *Session) { + t.mu.Lock() + defer t.mu.Unlock() + if t.getKeyspaceMetadata != nil { + // Init was already called. + // See https://github.com/scylladb/gocql/issues/94. + panic("sharing token aware host selection policy between sessions is not supported") + } t.getKeyspaceMetadata = s.KeyspaceMetadata t.getKeyspaceName = func() string { return s.cfg.Keyspace } t.logger = s.logger @@ -583,18 +604,42 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { var ( fallbackIter NextHost - i, j int - remote []*HostInfo + i, j, k int + remote [][]*HostInfo + tierer HostTierer + tiererOk bool + maxTier uint ) + if tierer, tiererOk = t.fallback.(HostTierer); tiererOk { + maxTier = tierer.MaxHostTier() + } else { + maxTier = 1 + } + + if t.nonLocalReplicasFallback { + remote = make([][]*HostInfo, maxTier) + } + used := make(map[*HostInfo]bool, len(replicas)) return func() SelectedHost { for i < len(replicas) { h := replicas[i] i++ - if !t.fallback.IsLocal(h) { - remote = append(remote, h) + var tier uint + if tiererOk { + tier = tierer.HostTier(h) + } else if t.fallback.IsLocal(h) { + tier = 0 + } else { + tier = 1 + } + + if tier != 0 { + if t.nonLocalReplicasFallback { + remote[tier-1] = append(remote[tier-1], h) + } continue } @@ -605,9 +650,14 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { } if t.nonLocalReplicasFallback { - for j < len(remote) { - h := remote[j] - j++ + for j < len(remote) && k < len(remote[j]) { + h := remote[j][k] + k++ + + if k >= len(remote[j]) { + j++ + k = 0 + } if h.IsUp() { used[h] = true @@ -639,14 +689,13 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { // use an empty slice of hosts as the hostpool will be populated later by gocql. 
// See below for examples of usage: // -// // Create host selection policy using a simple host pool -// cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(hostpool.New(nil)) -// -// // Create host selection policy using an epsilon greedy pool -// cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy( -// hostpool.NewEpsilonGreedy(nil, 0, &hostpool.LinearEpsilonValueCalculator{}), -// ) +// // Create host selection policy using a simple host pool +// cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(hostpool.New(nil)) // +// // Create host selection policy using an epsilon greedy pool +// cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy( +// hostpool.NewEpsilonGreedy(nil, 0, &hostpool.LinearEpsilonValueCalculator{}), +// ) func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy { return &hostPoolHostPolicy{hostMap: map[string]*HostInfo{}, hp: hp} } @@ -866,6 +915,68 @@ func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get()) } +// RackAwareRoundRobinPolicy is a host selection policy which will prioritize and +// return hosts which are in the local rack, before hosts in the local datacenter but +// a different rack, before hosts in all other datacenters + +type rackAwareRR struct { + // lastUsedHostIdx keeps the index of the last used host. + // It is accessed atomically and needs to be aligned to 64 bits, so we + // keep it first in the struct. Do not move it or add new struct members + // before it. + lastUsedHostIdx uint64 + localDC string + localRack string + hosts []cowHostList +} + +func RackAwareRoundRobinPolicy(localDC string, localRack string) HostSelectionPolicy { + hosts := make([]cowHostList, 3) + return &rackAwareRR{localDC: localDC, localRack: localRack, hosts: hosts} +} + +func (d *rackAwareRR) Init(*Session) {} +func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (d *rackAwareRR) SetPartitioner(p string) {} + +func (d *rackAwareRR) MaxHostTier() uint { + return 2 +} + +func (d *rackAwareRR) HostTier(host *HostInfo) uint { + if host.DataCenter() == d.localDC { + if host.Rack() == d.localRack { + return 0 + } else { + return 1 + } + } else { + return 2 + } +} + +func (d *rackAwareRR) IsLocal(host *HostInfo) bool { + return d.HostTier(host) == 0 +} + +func (d *rackAwareRR) AddHost(host *HostInfo) { + dist := d.HostTier(host) + d.hosts[dist].add(host) +} + +func (d *rackAwareRR) RemoveHost(host *HostInfo) { + dist := d.HostTier(host) + d.hosts[dist].remove(host.ConnectAddress()) +} + +func (d *rackAwareRR) HostUp(host *HostInfo) { d.AddHost(host) } +func (d *rackAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) } + +func (d *rackAwareRR) Pick(q ExecutableQuery) NextHost { + nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) + return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get()) +} + // ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After // each host connects during session initialization, the Ready method will be // called. If you only need a single Host to be up you can wrap a @@ -934,7 +1045,6 @@ func (e *SimpleConvictionPolicy) Reset(host *HostInfo) {} // ReconnectionPolicy interface is used by gocql to determine if reconnection // can be attempted after connection error. The interface allows gocql users // to implement their own logic to determine how to attempt reconnection.
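Because rackAwareRR implements the HostTierer interface above, wrapping it in a token-aware policy makes the tiered fallback try replicas in the local rack first, then the rest of the local datacenter, then remote datacenters. A configuration sketch with placeholder contact point, datacenter, and rack names; note that, per the Init check above, the policy value must not be shared between sessions:

package example

import "github.com/gocql/gocql"

func newRackAwareSession() (*gocql.Session, error) {
	cluster := gocql.NewCluster("10.0.0.1") // placeholder contact point
	// Token-aware selection with a rack-aware fallback; the non-local
	// replicas option lets the picker fall through to higher tiers.
	cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(
		gocql.RackAwareRoundRobinPolicy("dc1", "rack1"),
		gocql.NonLocalReplicasFallback(),
	)
	return cluster.CreateSession()
}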
-// type ReconnectionPolicy interface { GetInterval(currentRetry int) time.Duration GetMaxRetries() int } @@ -944,8 +1054,7 @@ type ReconnectionPolicy interface { // // Examples of usage: // -// cluster.ReconnectionPolicy = &gocql.ConstantReconnectionPolicy{MaxRetries: 10, Interval: 8 * time.Second} -// +// cluster.ReconnectionPolicy = &gocql.ConstantReconnectionPolicy{MaxRetries: 10, Interval: 8 * time.Second} type ConstantReconnectionPolicy struct { MaxRetries int Interval time.Duration diff --git a/vendor/github.com/gocql/gocql/query_executor.go b/vendor/github.com/gocql/gocql/query_executor.go index 9889316fbc3..e4dbed9cdc8 100644 --- a/vendor/github.com/gocql/gocql/query_executor.go +++ b/vendor/github.com/gocql/gocql/query_executor.go @@ -7,12 +7,15 @@ import ( ) type ExecutableQuery interface { + borrowForExecution() // Used to ensure that the query stays alive for the lifetime of a particular execution goroutine. + releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with an ok result or an error. execute(ctx context.Context, conn *Conn) *Iter attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) retryPolicy() RetryPolicy speculativeExecutionPolicy() SpeculativeExecutionPolicy GetRoutingKey() ([]byte, error) Keyspace() string + Table() string IsIdempotent() bool IsLWT() bool GetCustomPartitioner() partitioner @@ -45,6 +48,7 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S for i := 0; i < sp.Attempts(); i++ { select { case <-ticker.C: + qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) case <-ctx.Done(): return &Iter{err: ctx.Err()} @@ -82,6 +86,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { results := make(chan *Iter, 1) // Launch the main execution + qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) // The speculative executions are launched _in addition_ to the main @@ -173,4 +178,5 @@ func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter N case results <- q.do(ctx, qry, hostIter): case <-ctx.Done(): } + qry.releaseAfterExecution() } diff --git a/vendor/github.com/gocql/gocql/recreate.go b/vendor/github.com/gocql/gocql/recreate.go index af4a585e19c..64a4a20bbb0 100644 --- a/vendor/github.com/gocql/gocql/recreate.go +++ b/vendor/github.com/gocql/gocql/recreate.go @@ -109,10 +109,10 @@ var functionTemplate = template.Must(template.New("functions"). "stripFrozen": cqlHelpers.stripFrozen, }). Parse(` -CREATE FUNCTION {{ escape .keyspaceName }}.{{ escape .fm.Name }} ( +CREATE FUNCTION {{ .keyspaceName }}.{{ .fm.Name }} ( {{- range $i, $args := zip .fm.ArgumentNames .fm.ArgumentTypes }} {{- if ne $i 0 }}, {{ end }} - {{- escape (index $args 0) }} + {{- (index $args 0) }} {{ stripFrozen (index $args 1) }} {{- end -}}) {{ if .fm.CalledOnNullInput }}CALLED{{ else }}RETURNS NULL{{ end }} ON NULL INPUT @@ -167,19 +167,19 @@ var aggregatesTemplate = template.Must(template.New("aggregate"). }).
Parse(` CREATE AGGREGATE {{ .Keyspace }}.{{ .Name }}( - {{- range $arg, $i := .ArgumentTypes }} + {{- range $i, $arg := .ArgumentTypes }} {{- if ne $i 0 }}, {{ end }} {{ stripFrozen $arg }} {{- end -}}) SFUNC {{ .StateFunc.Name }} - STYPE {{ stripFrozen .State }} + STYPE {{ stripFrozen .StateType }} {{- if ne .FinalFunc.Name "" }} FINALFUNC {{ .FinalFunc.Name }} {{- end -}} {{- if ne .InitCond "" }} INITCOND {{ .InitCond }} {{- end -}} -); +; `)) func (km *KeyspaceMetadata) aggregateToCQL(w io.Writer, am *AggregateMetadata) error { diff --git a/vendor/github.com/gocql/gocql/ring.go b/vendor/github.com/gocql/gocql/ring.go index af296af2090..5b77370a160 100644 --- a/vendor/github.com/gocql/gocql/ring.go +++ b/vendor/github.com/gocql/gocql/ring.go @@ -26,8 +26,6 @@ type ring struct { } func (r *ring) rrHost() *HostInfo { - // TODO: should we filter hosts that get used here? These hosts will be used - // for the control connection, should we also provide an iterator? r.mu.RLock() defer r.mu.RUnlock() if len(r.hostList) == 0 { diff --git a/vendor/github.com/gocql/gocql/scylla.go b/vendor/github.com/gocql/gocql/scylla.go index 5fd8008ed46..7790a26eeb1 100644 --- a/vendor/github.com/gocql/gocql/scylla.go +++ b/vendor/github.com/gocql/gocql/scylla.go @@ -50,8 +50,58 @@ func findCQLProtoExtByName(exts []cqlProtocolExtension, name string) cqlProtocol // Each key identifies a single extension. const ( lwtAddMetadataMarkKey = "SCYLLA_LWT_ADD_METADATA_MARK" + rateLimitError = "SCYLLA_RATE_LIMIT_ERROR" ) +// "Rate limit" CQL Protocol Extension. +// This extension, if enabled (properly negotiated), allows Scylla server +// to send a special kind of error. +// +// Implements cqlProtocolExtension interface. +type rateLimitExt struct { + rateLimitErrorCode int +} + +var _ cqlProtocolExtension = &rateLimitExt{} + +// Factory function to deserialize and create an `rateLimitExt` instance +// from SUPPORTED message payload. +func newRateLimitExt(supported map[string][]string) *rateLimitExt { + const rateLimitErrorCode = "ERROR_CODE" + + if v, found := supported[rateLimitError]; found { + for i := range v { + splitVal := strings.Split(v[i], "=") + if splitVal[0] == rateLimitErrorCode { + var ( + err error + errorCode int + ) + if errorCode, err = strconv.Atoi(splitVal[1]); err != nil { + if gocqlDebug { + Logger.Printf("scylla: failed to parse %s value %v: %s", rateLimitErrorCode, splitVal[1], err) + return nil + } + } + return &rateLimitExt{ + rateLimitErrorCode: errorCode, + } + } + } + } + return nil +} + +func (ext *rateLimitExt) serialize() map[string]string { + return map[string]string{ + rateLimitError: "", + } +} + +func (ext *rateLimitExt) name() string { + return rateLimitError +} + // "LWT prepared statements metadata mark" CQL Protocol Extension. 
// This extension, if enabled (properly negotiated), allows Scylla server // to set a special bit in prepared statements metadata, which would indicate @@ -188,6 +238,11 @@ func parseCQLProtocolExtensions(supported map[string][]string) []cqlProtocolExte exts = append(exts, lwtExt) } + rateLimitExt := newRateLimitExt(supported) + if rateLimitExt != nil { + exts = append(exts, rateLimitExt) + } + return exts } diff --git a/vendor/github.com/gocql/gocql/session.go b/vendor/github.com/gocql/gocql/session.go index 3884b35c41e..f3058669e3b 100644 --- a/vendor/github.com/gocql/gocql/session.go +++ b/vendor/github.com/gocql/gocql/session.go @@ -41,7 +41,9 @@ type Session struct { batchObserver BatchObserver connectObserver ConnectObserver frameObserver FrameHeaderObserver + streamObserver StreamObserver hostSource *ringDescriber + ringRefresher *refreshDebouncer stmtsLRU *preparedLRU connCfg *ConnConfig @@ -72,8 +74,10 @@ type Session struct { // sessionStateMu protects isClosed and isInitialized. sessionStateMu sync.RWMutex - // isClosed is true once Session.Close is called. + // isClosed is true once Session.Close is finished. isClosed bool + // isClosing bool is true once Session.Close is started. + isClosing bool // isInitialized is true once Session.init succeeds. // you can use initialized() to read the value. isInitialized bool @@ -83,14 +87,14 @@ type Session struct { var queryPool = &sync.Pool{ New: func() interface{} { - return &Query{routingInfo: &queryRoutingInfo{}} + return &Query{routingInfo: &queryRoutingInfo{}, refCount: 1} }, } func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) { var hosts []*HostInfo - for _, hostport := range addrs { - resolvedHosts, err := hostInfo(hostport, defaultPort) + for _, hostaddr := range addrs { + resolvedHosts, err := hostInfo(hostaddr, defaultPort) if err != nil { // Try other hosts if unable to resolve DNS name if _, ok := err.(*net.DNSError); ok { @@ -151,6 +155,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) s.hostSource = &ringDescriber{session: s} + s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() error { return refreshRing(s.hostSource) }) if cfg.PoolConfig.HostSelectionPolicy == nil { cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() @@ -169,6 +174,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.batchObserver = cfg.BatchObserver s.connectObserver = cfg.ConnectObserver s.frameObserver = cfg.FrameHeaderObserver + s.streamObserver = cfg.StreamObserver //Check the TLS Config before trying to connect to anything external connCfg, err := connConfig(&s.cfg) @@ -474,11 +480,12 @@ func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) func (s *Session) Close() { s.sessionStateMu.Lock() - defer s.sessionStateMu.Unlock() - if s.isClosed { + if s.isClosing { + s.sessionStateMu.Unlock() return } - s.isClosed = true + s.isClosing = true + s.sessionStateMu.Unlock() if s.pool != nil { s.pool.Close() @@ -496,9 +503,17 @@ func (s *Session) Close() { s.schemaEvents.stop() } + if s.ringRefresher != nil { + s.ringRefresher.stop() + } + if s.cancel != nil { s.cancel() } + + s.sessionStateMu.Lock() + s.isClosed = true + s.sessionStateMu.Unlock() } func (s *Session) Closed() bool { @@ -628,8 +643,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI return nil, nil } - table := info.request.columns[0].Table - keyspace := info.request.columns[0].Keyspace + 
table := info.request.table + keyspace := info.request.keyspace partitioner, err := scyllaGetTablePartitioner(s, keyspace, table) if err != nil { @@ -650,6 +665,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI types: types, lwt: info.request.lwt, partitioner: partitioner, + keyspace: keyspace, + table: table, } inflight.value = routingKeyInfo @@ -685,6 +702,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI types: make([]TypeInfo, size), lwt: info.request.lwt, partitioner: partitioner, + keyspace: keyspace, + table: table, } for keyIndex, keyColumn := range partitionKey { @@ -910,6 +929,7 @@ type Query struct { idempotent bool customPayload map[string][]byte metrics *queryMetrics + refCount uint32 disableAutoPage bool @@ -935,6 +955,10 @@ type queryRoutingInfo struct { // If not nil, represents a custom partitioner for the table. partitioner partitioner + + keyspace string + + table string } func (qri *queryRoutingInfo) isLWT() bool { @@ -973,12 +997,18 @@ func (q Query) Statement() string { return q.stmt } +// Values returns the values passed in via Bind. +// This can be used by a wrapper type that needs to access the bound values. +func (q Query) Values() []interface{} { + return q.values +} + // String implements the stringer interface. func (q Query) String() string { return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.cons) } -//Attempts returns the number of times the query was executed. +// Attempts returns the number of times the query was executed. func (q *Query) Attempts() int { return q.metrics.attempts() } @@ -987,7 +1017,7 @@ func (q *Query) AddAttempts(i int, host *HostInfo) { q.metrics.attempt(i, 0, host, false) } -//Latency returns the average amount of nanoseconds per attempt of the query. +// Latency returns the average amount of nanoseconds per attempt of the query. func (q *Query) Latency() int64 { return q.metrics.latency() } @@ -1136,6 +1166,10 @@ func (q *Query) Keyspace() string { if q.getKeyspace != nil { return q.getKeyspace() } + if q.routingInfo.keyspace != "" { + return q.routingInfo.keyspace + } + if q.session == nil { return "" } @@ -1144,6 +1178,11 @@ func (q *Query) Keyspace() string { return q.session.cfg.Keyspace } +// Table returns name of the table the query will be executed against. +func (q *Query) Table() string { + return q.routingInfo.table +} + // GetRoutingKey gets the routing key to use for routing this query. If // a routing key has not been explicitly set, then the routing key will // be constructed if possible using the keyspace's schema and the query @@ -1165,10 +1204,13 @@ func (q *Query) GetRoutingKey() ([]byte, error) { if err != nil { return nil, err } + if routingKeyInfo != nil { q.routingInfo.mu.Lock() q.routingInfo.lwt = routingKeyInfo.lwt q.routingInfo.partitioner = routingKeyInfo.partitioner + q.routingInfo.keyspace = routingKeyInfo.keyspace + q.routingInfo.table = routingKeyInfo.table q.routingInfo.mu.Unlock() } return createRoutingKey(routingKeyInfo, q.values) @@ -1387,17 +1429,37 @@ func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error // cannot be reused. // // Example: -// qry := session.Query("SELECT * FROM my_table") -// qry.Exec() -// qry.Release() +// +// qry := session.Query("SELECT * FROM my_table") +// qry.Exec() +// qry.Release() func (q *Query) Release() { - q.reset() - queryPool.Put(q) + q.decRefCount() } // reset zeroes out all fields of a query so that it can be safely pooled. 
func (q *Query) reset() { - *q = Query{routingInfo: &queryRoutingInfo{}} + *q = Query{routingInfo: &queryRoutingInfo{}, refCount: 1} +} + +func (q *Query) incRefCount() { + atomic.AddUint32(&q.refCount, 1) +} + +func (q *Query) decRefCount() { + if res := atomic.AddUint32(&q.refCount, ^uint32(0)); res == 0 { + // do release + q.reset() + queryPool.Put(q) + } +} + +func (q *Query) borrowForExecution() { + q.incRefCount() +} + +func (q *Query) releaseAfterExecution() { + q.decRefCount() } // Iter represents an iterator that can be used to iterate over all rows that @@ -1615,7 +1677,10 @@ func (iter *Iter) Scan(dest ...interface{}) bool { // custom QueryHandlers running in your C* cluster. // See https://datastax.github.io/java-driver/manual/custom_payloads/ func (iter *Iter) GetCustomPayload() map[string][]byte { - return iter.framer.customPayload + if iter.framer != nil { + return iter.framer.customPayload + } + return nil } // Warnings returns any warnings generated if given in the response from Cassandra. @@ -1773,6 +1838,11 @@ func (b *Batch) Keyspace() string { return b.keyspace } +// Batch has no reasonable equivalent of Query.Table(). +func (b *Batch) Table() string { + return b.routingInfo.table +} + // Attempts returns the number of attempts made to execute the batch. func (b *Batch) Attempts() int { return b.metrics.attempts() } @@ -1782,7 +1852,7 @@ func (b *Batch) AddAttempts(i int, host *HostInfo) { b.metrics.attempt(i, 0, host, false) } -//Latency returns the average number of nanoseconds to execute a single attempt of the batch. +// Latency returns the average number of nanoseconds to execute a single attempt of the batch. func (b *Batch) Latency() int64 { return b.metrics.latency() } @@ -2018,6 +2088,16 @@ func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]b return routingKey, nil } +func (b *Batch) borrowForExecution() { + // empty, because Batch has no equivalent of Query.Release() + // that would race with speculative executions. +} + +func (b *Batch) releaseAfterExecution() { + // empty, because Batch has no equivalent of Query.Release() + // that would race with speculative executions. +} + type BatchType byte const ( @@ -2051,8 +2131,10 @@ type routingKeyInfoLRU struct { } type routingKeyInfo struct { - indexes []int - types []TypeInfo + indexes []int + types []TypeInfo + keyspace string + table string lwt bool partitioner partitioner } @@ -2067,8 +2149,8 @@ func (r *routingKeyInfoLRU) Remove(key string) { r.mu.Unlock() } -//Max adjusts the maximum size of the cache and cleans up the oldest records if -//the new max is lower than the previous value. Not concurrency safe. +// Max adjusts the maximum size of the cache and cleans up the oldest records if +// the new max is lower than the previous value. Not concurrency safe.
func (r *routingKeyInfoLRU) Max(max int) { r.mu.Lock() for r.lru.Len() > max { @@ -2127,6 +2209,7 @@ func (t *traceWriter) Trace(traceId []byte) { activity string source string elapsed int + thread string ) t.mu.Lock() @@ -2135,13 +2218,13 @@ func (t *traceWriter) Trace(traceId []byte) { fmt.Fprintf(t.w, "Tracing session %016x (coordinator: %s, duration: %v):\n", traceId, coordinator, time.Duration(duration)*time.Microsecond) - iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed + iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed, thread FROM system_traces.events WHERE session_id = ?`, traceId) - for iter.Scan(&timestamp, &activity, &source, &elapsed) { - fmt.Fprintf(t.w, "%s: %s (source: %s, elapsed: %d)\n", - timestamp.Format("2006/01/02 15:04:05.999999"), activity, source, elapsed) + for iter.Scan(&timestamp, &activity, &source, &elapsed, &thread) { + fmt.Fprintf(t.w, "%s: %s [%s] (source: %s, elapsed: %d)\n", + timestamp.Format("2006/01/02 15:04:05.999999"), activity, thread, source, elapsed) } if err := iter.Close(); err != nil { diff --git a/vendor/github.com/gocql/gocql/version.go b/vendor/github.com/gocql/gocql/version.go new file mode 100644 index 00000000000..015b40e1eed --- /dev/null +++ b/vendor/github.com/gocql/gocql/version.go @@ -0,0 +1,28 @@ +package gocql + +import "runtime/debug" + +const ( + mainModule = "github.com/gocql/gocql" +) + +var driverName string + +var driverVersion string + +func init() { + buildInfo, ok := debug.ReadBuildInfo() + if ok { + for _, d := range buildInfo.Deps { + if d.Path == mainModule { + driverName = mainModule + driverVersion = d.Version + if d.Replace != nil { + driverName = d.Replace.Path + driverVersion = d.Replace.Version + } + break + } + } + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 3295b593cd5..05c0174a806 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -324,7 +324,7 @@ github.com/go-task/slim-sprig # github.com/gobuffalo/flect v1.0.2 ## explicit; go 1.16 github.com/gobuffalo/flect -# github.com/gocql/gocql v1.6.0 => github.com/scylladb/gocql v1.7.3 +# github.com/gocql/gocql v1.6.0 => github.com/scylladb/gocql v1.12.0 ## explicit; go 1.13 github.com/gocql/gocql github.com/gocql/gocql/internal/lru @@ -1516,4 +1516,4 @@ sigs.k8s.io/structured-merge-diff/v4/value # sigs.k8s.io/yaml v1.3.0 ## explicit; go 1.12 sigs.k8s.io/yaml -# github.com/gocql/gocql => github.com/scylladb/gocql v1.7.3 +# github.com/gocql/gocql => github.com/scylladb/gocql v1.12.0
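The new version.go resolves the driver's identity through module build info, so with the replace directive in go.mod the reported name and version are those of the scylladb fork rather than the upstream path. A sketch of the same lookup from application code, useful for confirming which driver build ended up vendored; the commented output is what this go.mod would produce:

package main

import (
	"fmt"
	"runtime/debug"
)

func main() {
	bi, ok := debug.ReadBuildInfo()
	if !ok {
		return // built without module info (e.g. outside module mode)
	}
	for _, d := range bi.Deps {
		if d.Path != "github.com/gocql/gocql" {
			continue
		}
		if d.Replace != nil {
			// With this go.mod: github.com/scylladb/gocql v1.12.0
			fmt.Println(d.Replace.Path, d.Replace.Version)
		} else {
			fmt.Println(d.Path, d.Version)
		}
	}
}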