Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: properly stop client on Stop #188

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ type client struct {
func (c *client) Start() {
c.setState(ClientConnecting)
boff := c.backoffFactory()
c.partyBase.waitGroup().Add(1)
go func() {
defer c.partyBase.waitGroup().Done()
for {
c.setErr(nil)
// Listen for state change to ClientConnected and signal backoff Reset then.
Expand Down Expand Up @@ -221,8 +223,9 @@ func (c *client) Start() {
func (c *client) Stop() {
if c.cancelFunc != nil {
c.cancelFunc()
c.partyBase.waitGroup().Wait()
c.setState(ClientClosed)
}
c.setState(ClientClosed)
}

func (c *client) run() error {
Expand Down
70 changes: 70 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ func (s *simpleReceiver) OnCallback(result string) {
s.ch <- result
}

type noLogAfterStopLogger struct {
StructuredLogger
shouldPanic atomic.Bool
}

func (n *noLogAfterStopLogger) Log(keyVals ...interface{}) error {
if n.shouldPanic.Load() {
panic("oh no")
}
return n.StructuredLogger.Log(keyVals)
}

var _ = Describe("Client", func() {
formatOption := TransferFormat("Text")
j := 1
Expand Down Expand Up @@ -149,6 +161,64 @@ var _ = Describe("Client", func() {
close(done)
}, 1.0)
})
Context("Stop", func() {
It("should stop the client properly", func(done Done) {
// Create a simple server
server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}),
testLoggerOption(),
ChanReceiveTimeout(200*time.Millisecond),
StreamBufferCapacity(5))
Expect(err).NotTo(HaveOccurred())
Expect(server).NotTo(BeNil())
// Create both ends of the connection
cliConn, srvConn := newClientServerConnections()
// Start the server
go func() { _ = server.Serve(srvConn) }()
// Create the Client
clientConn, err := NewClient(context.Background(), WithConnection(cliConn), testLoggerOption(), formatOption)
Expect(err).NotTo(HaveOccurred())
Expect(clientConn).NotTo(BeNil())
// Start it
clientConn.Start()
Expect(<-clientConn.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
clientConn.Stop()
Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed))
server.cancel()
close(done)
})
It("should not log after stop", func(done Done) {
// Create a simple server
server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}),
testLoggerOption(),
ChanReceiveTimeout(200*time.Millisecond),
StreamBufferCapacity(5))
Expect(err).NotTo(HaveOccurred())
Expect(server).NotTo(BeNil())
// Create both ends of the connection
cliConn, srvConn := newClientServerConnections()
// Start the server
go func() { _ = server.Serve(srvConn) }()
// Create the Client
clientConn, err := NewClient(context.Background(), WithConnection(cliConn), testLoggerOption(), formatOption)
Expect(err).NotTo(HaveOccurred())
Expect(clientConn).NotTo(BeNil())
// Replace loggers with loggers that panic after stop
info, debug := clientConn.loggers()
panicableInfo, panicableDebug := &noLogAfterStopLogger{StructuredLogger: info}, &noLogAfterStopLogger{StructuredLogger: debug}
clientConn.setLoggers(panicableInfo, panicableDebug)
// Start it
clientConn.Start()
Expect(<-clientConn.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
clientConn.Stop()
panicableInfo.shouldPanic.Store(true)
panicableDebug.shouldPanic.Store(true)
// Ensure that we really don't get any logs anymore
time.Sleep(500 * time.Millisecond)
Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed))
server.cancel()
close(done)
})
})
Context("Invoke", func() {
It("should invoke a server method and return the result", func(done Done) {
_, client, _, cancelClient := getTestBed(&simpleReceiver{}, formatOption)
Expand Down
3 changes: 3 additions & 0 deletions loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ func (l *loop) Run(connected chan struct{}) (err error) {
close(connected)
// Process messages
ch := make(chan receiveResult, 1)
wg := l.party.waitGroup()
wg.Add(1)
go func() {
defer wg.Done()
recv := l.hubConn.Receive()
loop:
for {
Expand Down
20 changes: 14 additions & 6 deletions party.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package signalr

import (
"context"
"sync"
"time"

"github.com/go-kit/log"
Expand Down Expand Up @@ -29,7 +30,7 @@ type Party interface {
insecureSkipVerify() bool
setInsecureSkipVerify(skip bool)

originPatterns() [] string
originPatterns() []string
setOriginPatterns(orgs []string)

chanReceiveTimeout() time.Duration
Expand All @@ -50,6 +51,8 @@ type Party interface {

maximumReceiveMessageSize() uint
setMaximumReceiveMessageSize(size uint)

waitGroup() *sync.WaitGroup
}

func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger) partyBase {
Expand Down Expand Up @@ -81,10 +84,11 @@ type partyBase struct {
_streamBufferCapacity uint
_maximumReceiveMessageSize uint
_enableDetailedErrors bool
_insecureSkipVerify bool
_originPatterns []string
_insecureSkipVerify bool
_originPatterns []string
info StructuredLogger
dbg StructuredLogger
wg sync.WaitGroup
}

func (p *partyBase) context() context.Context {
Expand Down Expand Up @@ -120,16 +124,16 @@ func (p *partyBase) setKeepAliveInterval(interval time.Duration) {
}

func (p *partyBase) insecureSkipVerify() bool {
return p._insecureSkipVerify
return p._insecureSkipVerify
}
func (p *partyBase) setInsecureSkipVerify(skip bool) {
p._insecureSkipVerify = skip
}

func (p *partyBase) originPatterns() []string {
return p._originPatterns
return p._originPatterns
}
func (p *partyBase) setOriginPatterns(origins []string) {
func (p *partyBase) setOriginPatterns(origins []string) {
p._originPatterns = origins
}

Expand Down Expand Up @@ -173,3 +177,7 @@ func (p *partyBase) setLoggers(info StructuredLogger, dbg StructuredLogger) {
func (p *partyBase) loggers() (info StructuredLogger, debug StructuredLogger) {
return p.info, p.dbg
}

func (p *partyBase) waitGroup() *sync.WaitGroup {
return &p.wg
}
Loading