Skip to content

Commit

Permalink
producer: add context.Context to Publish
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulminator committed May 8, 2024
1 parent c2c3842 commit cd8a883
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 29 deletions.
47 changes: 35 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"compress/flate"
"context"
"crypto/tls"
"encoding/json"
"errors"
Expand Down Expand Up @@ -119,8 +120,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) {
c.logGuard.Lock()
defer c.logGuard.Unlock()
Expand Down Expand Up @@ -171,13 +171,20 @@ func (c *Conn) getLogLevel() LogLevel {
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) Connect() (*IdentifyResponse, error) {
ctx := context.Background()
return c.ConnectWithContext(ctx)
}

func (c *Conn) ConnectWithContext(ctx context.Context) (*IdentifyResponse, error) {
dialer := &net.Dialer{
LocalAddr: c.config.LocalAddr,
Timeout: c.config.DialTimeout,
}

conn, err := dialer.Dial("tcp", c.addr)
// the timeout used is smallest of dialer.Timeout (config.DialTimeout) or context timeout
conn, err := dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
fmt.Println("dialer.DialContext error: ", err) // TODO: remove
return nil, err
}
c.conn = conn.(*net.TCPConn)
Expand All @@ -190,7 +197,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err)
}

resp, err := c.identify()
resp, err := c.identify(ctx)
if err != nil {
return nil, err
}
Expand All @@ -200,7 +207,8 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
c.log(LogLevelError, "Auth Required")
return nil, errors.New("Auth Required")
}
err := c.auth(c.config.AuthSecret)
// should context passed into c.auth()?
err := c.auth(ctx, c.config.AuthSecret)
if err != nil {
c.log(LogLevelError, "Auth Failed %s", err)
return nil, err
Expand Down Expand Up @@ -291,13 +299,28 @@ func (c *Conn) Write(p []byte) (int, error) {
// WriteCommand is a goroutine safe method to write a Command
// to this connection, and flush.
func (c *Conn) WriteCommand(cmd *Command) error {
ctx := context.Background()
return c.WriteCommandWithContext(ctx, cmd)
}

func (c *Conn) WriteCommandWithContext(ctx context.Context, cmd *Command) error {
// would we want all of our usage of WriteCommand to be replaced with WriteCommandWithContext?
c.mtx.Lock()

_, err := cmd.WriteTo(c)
if err != nil {
var err error
select {
case <-ctx.Done():
fmt.Println("WriteCommandWithContext ctx.Done(): ", ctx.Err()) // TODO: remove
c.mtx.Unlock()
return ctx.Err()
default:
_, err := cmd.WriteTo(c)
if err != nil {
goto exit
}
err = c.Flush()
goto exit
}
err = c.Flush()

exit:
c.mtx.Unlock()
Expand All @@ -320,7 +343,7 @@ func (c *Conn) Flush() error {
return nil
}

func (c *Conn) identify() (*IdentifyResponse, error) {
func (c *Conn) identify(ctx context.Context) (*IdentifyResponse, error) {
ci := make(map[string]interface{})
ci["client_id"] = c.config.ClientID
ci["hostname"] = c.config.Hostname
Expand Down Expand Up @@ -350,7 +373,7 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return nil, ErrIdentify{err.Error()}
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return nil, ErrIdentify{err.Error()}
}
Expand Down Expand Up @@ -479,13 +502,13 @@ func (c *Conn) upgradeSnappy() error {
return nil
}

func (c *Conn) auth(secret string) error {
func (c *Conn) auth(ctx context.Context, secret string) error {
cmd, err := Auth(secret)
if err != nil {
return err
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return err
}
Expand Down
87 changes: 70 additions & 17 deletions producer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nsq

import (
"context"
"fmt"
"log"
"os"
Expand All @@ -15,8 +16,10 @@ type producerConn interface {
SetLoggerLevel(LogLevel)
SetLoggerForLevel(logger, LogLevel, string)
Connect() (*IdentifyResponse, error)
ConnectWithContext(context.Context) (*IdentifyResponse, error)
Close() error
WriteCommand(*Command) error
WriteCommandWithContext(context.Context, *Command) error
}

// Producer is a high-level type to publish to NSQ.
Expand Down Expand Up @@ -53,6 +56,7 @@ type Producer struct {
// to retrieve metadata about the command after the
// response is received.
type ProducerTransaction struct {
ctx context.Context
cmd *Command
doneChan chan *ProducerTransaction
Error error // the error (or nil) of the publish command
Expand Down Expand Up @@ -105,23 +109,27 @@ func NewProducer(addr string, config *Config) (*Producer, error) {
// configured correctly, rather than relying on the lazy "connect on Publish"
// behavior of a Producer.
func (w *Producer) Ping() error {
ctx := context.Background()
return w.PingWithContext(ctx)
}

func (w *Producer) PingWithContext(ctx context.Context) error {
if atomic.LoadInt32(&w.state) != StateConnected {
err := w.connect()
err := w.connect(ctx)
if err != nil {
return err
}
}

return w.conn.WriteCommand(Nop())
return w.conn.WriteCommandWithContext(ctx, Nop())
}

// SetLogger assigns the logger to use as well as a level
//
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (w *Producer) SetLogger(l logger, lvl LogLevel) {
w.logGuard.Lock()
defer w.logGuard.Unlock()
Expand Down Expand Up @@ -192,7 +200,13 @@ func (w *Producer) Stop() {
// and the response error if present
func (w *Producer) PublishAsync(topic string, body []byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
return w.sendCommandAsync(Publish(topic, body), doneChan, args)
ctx := context.Background()
return w.PublishAsyncWithContext(ctx, topic, body, doneChan, args...)
}

func (w *Producer) PublishAsyncWithContext(ctx context.Context, topic string, body []byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
return w.sendCommandAsync(ctx, Publish(topic, body), doneChan, args)
}

// MultiPublishAsync publishes a slice of message bodies to the specified topic
Expand All @@ -203,35 +217,56 @@ func (w *Producer) PublishAsync(topic string, body []byte, doneChan chan *Produc
// will receive a `ProducerTransaction` instance with the supplied variadic arguments
// and the response error if present
func (w *Producer) MultiPublishAsync(topic string, body [][]byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
ctx := context.Background()
return w.MultiPublishAsyncWithContext(ctx, topic, body, doneChan, args...)
}

func (w *Producer) MultiPublishAsyncWithContext(ctx context.Context, topic string, body [][]byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
cmd, err := MultiPublish(topic, body)
if err != nil {
return err
}
return w.sendCommandAsync(cmd, doneChan, args)
return w.sendCommandAsync(ctx, cmd, doneChan, args)
}

// Publish synchronously publishes a message body to the specified topic, returning
// an error if publish failed
func (w *Producer) Publish(topic string, body []byte) error {
return w.sendCommand(Publish(topic, body))
ctx := context.Background()
return w.PublishWithContext(ctx, topic, body)
}

func (w *Producer) PublishWithContext(ctx context.Context, topic string, body []byte) error {
return w.sendCommand(ctx, Publish(topic, body))
}

// MultiPublish synchronously publishes a slice of message bodies to the specified topic, returning
// an error if publish failed
func (w *Producer) MultiPublish(topic string, body [][]byte) error {
ctx := context.Background()
return w.MultiPublishWithContext(ctx, topic, body)
}

func (w *Producer) MultiPublishWithContext(ctx context.Context, topic string, body [][]byte) error {
cmd, err := MultiPublish(topic, body)
if err != nil {
return err
}
return w.sendCommand(cmd)
return w.sendCommand(ctx, cmd)
}

// DeferredPublish synchronously publishes a message body to the specified topic
// where the message will queue at the channel level until the timeout expires, returning
// an error if publish failed
func (w *Producer) DeferredPublish(topic string, delay time.Duration, body []byte) error {
return w.sendCommand(DeferredPublish(topic, delay, body))
ctx := context.Background()
return w.DeferredPublishWithContext(ctx, topic, delay, body)
}

func (w *Producer) DeferredPublishWithContext(ctx context.Context, topic string, delay time.Duration, body []byte) error {
return w.sendCommand(ctx, DeferredPublish(topic, delay, body))
}

// DeferredPublishAsync publishes a message body to the specified topic
Expand All @@ -244,12 +279,18 @@ func (w *Producer) DeferredPublish(topic string, delay time.Duration, body []byt
// and the response error if present
func (w *Producer) DeferredPublishAsync(topic string, delay time.Duration, body []byte,
doneChan chan *ProducerTransaction, args ...interface{}) error {
return w.sendCommandAsync(DeferredPublish(topic, delay, body), doneChan, args)
ctx := context.Background()
return w.DeferredPublishAsyncWithContext(ctx, topic, delay, body, doneChan, args...)
}

func (w *Producer) DeferredPublishAsyncWithContext(ctx context.Context, topic string, delay time.Duration, body []byte,
doneChan chan *ProducerTransaction, args ...interface{}) error {
return w.sendCommandAsync(ctx, DeferredPublish(topic, delay, body), doneChan, args)
}

func (w *Producer) sendCommand(cmd *Command) error {
func (w *Producer) sendCommand(ctx context.Context, cmd *Command) error {
doneChan := make(chan *ProducerTransaction)
err := w.sendCommandAsync(cmd, doneChan, nil)
err := w.sendCommandAsync(ctx, cmd, doneChan, nil)
if err != nil {
close(doneChan)
return err
Expand All @@ -258,21 +299,22 @@ func (w *Producer) sendCommand(cmd *Command) error {
return t.Error
}

func (w *Producer) sendCommandAsync(cmd *Command, doneChan chan *ProducerTransaction,
func (w *Producer) sendCommandAsync(ctx context.Context, cmd *Command, doneChan chan *ProducerTransaction,
args []interface{}) error {
// keep track of how many outstanding producers we're dealing with
// in order to later ensure that we clean them all up...
atomic.AddInt32(&w.concurrentProducers, 1)
defer atomic.AddInt32(&w.concurrentProducers, -1)

if atomic.LoadInt32(&w.state) != StateConnected {
err := w.connect()
err := w.connect(ctx)
if err != nil {
return err
}
}

t := &ProducerTransaction{
ctx: ctx,
cmd: cmd,
doneChan: doneChan,
Args: args,
Expand All @@ -282,12 +324,15 @@ func (w *Producer) sendCommandAsync(cmd *Command, doneChan chan *ProducerTransac
case w.transactionChan <- t:
case <-w.exitChan:
return ErrStopped
case <-ctx.Done():
fmt.Println("sendCommandAsync ctx.Done(): ", ctx.Err()) // TODO: remove
return ctx.Err()
}

return nil
}

func (w *Producer) connect() error {
func (w *Producer) connect(ctx context.Context) error {
w.guard.Lock()
defer w.guard.Unlock()

Expand All @@ -312,7 +357,7 @@ func (w *Producer) connect() error {
w.conn.SetLoggerForLevel(w.logger[index], LogLevel(index), format)
}

_, err := w.conn.Connect()
_, err := w.conn.ConnectWithContext(ctx)
if err != nil {
w.conn.Close()
w.log(LogLevelError, "(%s) error connecting to nsqd - %s", w.addr, err)
Expand Down Expand Up @@ -344,9 +389,17 @@ func (w *Producer) router() {
select {
case t := <-w.transactionChan:
w.transactions = append(w.transactions, t)
err := w.conn.WriteCommand(t.cmd)
err := w.conn.WriteCommandWithContext(t.ctx, t.cmd)
if err != nil {
w.log(LogLevelError, "(%s) sending command - %s", w.conn.String(), err)
if err == context.Canceled || err == context.DeadlineExceeded {
// keep the connection alive if related to context timeout
// need to do some stuff that's in Producer.popTransaction here
w.transactions = w.transactions[1:]
t.Error = err
t.finish()
continue
}
w.close()
}
case data := <-w.responseChan:
Expand Down
Loading

0 comments on commit cd8a883

Please sign in to comment.