From 85843c5276c5aed954c4ead7881de03f8cffa89e Mon Sep 17 00:00:00 2001
From: Philipp Seith <6890503+philippseith@users.noreply.github.com>
Date: Thu, 21 Oct 2021 18:44:49 +0200
Subject: [PATCH] fix: WebSocketConnection: Timeout on Read shuts down writing
 connection

webSocketConnection Read and Write had separate timeouts which, when expired, mutually closed the websocket.
In case only one party of the connection was sending, the websocket was closed after the timeout despite the fact that the connection was alive.
This comes from misunderstanding the impact of the Read/Write context parameter. See https://github.com/nhooyr/websocket/issues/242
---
 websocketconnection.go | 80 +++++++++++++++++++++++++++++++++---------
 1 file changed, 64 insertions(+), 16 deletions(-)

diff --git a/websocketconnection.go b/websocketconnection.go
index b72de126..43f56880 100644
--- a/websocketconnection.go
+++ b/websocketconnection.go
@@ -4,26 +4,29 @@ import (
 	"bytes"
 	"context"
 	"fmt"
-
 	"github.com/teivah/onecontext"
 	"nhooyr.io/websocket"
+	"time"
 )
 
 type webSocketConnection struct {
 	ConnectionBase
 	conn         *websocket.Conn
 	transferMode TransferMode
+	watchDogChan chan dogFood // carries re-arm requests from resetWatchDog to the watchDog goroutine
 }
 
 func newWebSocketConnection(parentContext context.Context, requestContext context.Context, connectionID string, conn *websocket.Conn) *webSocketConnection {
 	ctx, _ := onecontext.Merge(parentContext, requestContext)
 	w := &webSocketConnection{
-		conn: conn,
+		conn:         conn,
+		watchDogChan: make(chan dogFood, 1),
 		ConnectionBase: ConnectionBase{
 			ctx:          ctx,
 			connectionID: connectionID,
 		},
 	}
+	go w.watchDog(ctx) // the watchDog goroutine returns once ctx is done
 	return w
 }
 
@@ -31,17 +34,11 @@ func (w *webSocketConnection) Write(p []byte) (n int, err error) {
 	if err := w.Context().Err(); err != nil {
 		return 0, fmt.Errorf("webSocketConnection canceled: %w", w.ctx.Err())
 	}
-	ctx := w.ctx
-	if w.timeout > 0 {
-		var cancel context.CancelFunc
-		ctx, cancel = context.WithTimeout(w.ctx, w.Timeout())
-		defer cancel() // has no effect because timeoutCtx is either done or not used anymore after websocket returns. But it keeps lint quiet
-	}
 	messageType := websocket.MessageText
 	if w.transferMode == BinaryTransferMode {
 		messageType = websocket.MessageBinary
 	}
-	err = w.conn.Write(ctx, messageType, p)
+	err = w.conn.Write(w.resetWatchDog(), messageType, p)
 	if err != nil {
 		return 0, err
 	}
@@ -52,19 +49,70 @@ func (w *webSocketConnection) Read(p []byte) (n int, err error) {
 	if err := w.Context().Err(); err != nil {
 		return 0, fmt.Errorf("webSocketConnection canceled: %w", w.ctx.Err())
 	}
-	ctx := w.ctx
-	if w.timeout > 0 {
-		var cancel context.CancelFunc
-		ctx, cancel = context.WithTimeout(w.ctx, w.Timeout())
-		defer cancel() // has no effect because timeoutCtx is either done or not used anymore after websocket returns. But it keeps lint quiet
-	}
-	_, data, err := w.conn.Read(ctx)
+	_, data, err := w.conn.Read(w.resetWatchDog())
 	if err != nil {
 		return 0, err
 	}
 	return bytes.NewReader(data).Read(p)
 }
 
+// resetWatchDog re-arms the watchDog shared by Read and Write and returns the context to pass to the websocket call.
+// With timeout > 0 that context is a cancelable child of w.ctx which the watchDog cancels on expiry; otherwise it is w.ctx.
+func (w *webSocketConnection) resetWatchDog() context.Context {
+	ctx := w.ctx
+	food := dogFood{timeout: w.timeout}
+	if w.timeout > 0 {
+		ctx, food.bark = context.WithCancel(w.ctx)
+	}
+	w.watchDogChan <- food
+	return ctx
+}
+
+// dogFood is the message resetWatchDog sends to the watchDog to reset it
+type dogFood struct {
+	// timeout is the duration after which the dog will bark
+	timeout time.Duration
+	bark    context.CancelFunc // cancels the context returned by resetWatchDog; nil unless timeout > 0
+}
+
+// watchDog is the common watchDog for Read and Write. It stops the connection (aka closes the Websocket)
+// when the last timeout has elapsed. If resetWatchDog is called before the last timeout has elapsed,
+// the watchDog will restart waiting for the new timeout. If timeout is set to 0, it will not wait at all.
+func (w *webSocketConnection) watchDog(ctx context.Context) {
+	// All timer state is owned by this goroutine. Helper goroutines that
+	// captured the timer or a cancel channel could observe the values of a
+	// later reset and disarm (or steal the expiry of) the wrong timeout.
+	var timer *time.Timer
+	var expired <-chan time.Time // nil, i.e. never ready, while disarmed
+	var bark context.CancelFunc
+	for {
+		select {
+		case <-ctx.Done():
+			if timer != nil {
+				timer.Stop()
+			}
+			return
+		case <-expired:
+			// The armed timeout elapsed: cancel the context handed out for it.
+			bark()
+			timer, expired, bark = nil, nil, nil
+		case food := <-w.watchDogChan:
+			// Drop the previous deadline without barking: the call that
+			// armed it may still be in flight and must not be canceled.
+			if timer != nil {
+				timer.Stop()
+			}
+			timer, expired, bark = nil, nil, nil
+			// resetWatchDog only sets bark for timeout > 0; anything else
+			// means "do not wait" and must never reach a nil bark.
+			if food.timeout > 0 && food.bark != nil {
+				timer = time.NewTimer(food.timeout)
+				expired, bark = timer.C, food.bark
+			}
+		}
+	}
+}
+
 func (w *webSocketConnection) TransferMode() TransferMode {
 	return w.transferMode
 }