Skip to content

Commit

Permalink
Merge pull request #135 from philippseith/bugfix/#133_PushStreams_sen…
Browse files Browse the repository at this point in the history
…ds_StreamInvocation

Bugfix/#133 push streams sends stream invocation
  • Loading branch information
philippseith authored Aug 22, 2022
2 parents 7231a2c + fa93e6c commit a16987d
Show file tree
Hide file tree
Showing 8 changed files with 4,900 additions and 73 deletions.
22 changes: 11 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ type Client interface {
Invoke(method string, arguments ...interface{}) <-chan InvokeResult
Send(method string, arguments ...interface{}) <-chan error
PullStream(method string, arguments ...interface{}) <-chan InvokeResult
PushStreams(method string, arguments ...interface{}) <-chan error
PushStreams(method string, arguments ...interface{}) <-chan InvokeResult
}

var ErrUnableToConnect = errors.New("neither WithConnection nor WithConnector option was given")
Expand Down Expand Up @@ -462,28 +462,28 @@ func (c *client) PullStream(method string, arguments ...interface{}) <-chan Invo
return irCh
}

func (c *client) PushStreams(method string, arguments ...interface{}) <-chan error {
errCh := make(chan error, 1)
func (c *client) PushStreams(method string, arguments ...interface{}) <-chan InvokeResult {
irCh := make(chan InvokeResult, 1)
go func() {
if err := <-c.waitForConnected(); err != nil {
errCh <- err
close(errCh)
irCh <- InvokeResult{Error: err}
close(irCh)
return
}
pushCh, err := c.loop.PushStreams(method, c.loop.GetNewID(), arguments...)
if err != nil {
errCh <- err
close(errCh)
irCh <- InvokeResult{Error: err}
close(irCh)
return
}
go func() {
for err := range pushCh {
errCh <- err
for ir := range pushCh {
irCh <- ir
}
close(errCh)
close(irCh)
}()
}()
return errCh
return irCh
}

func (c *client) waitForConnected() <-chan error {
Expand Down
29 changes: 15 additions & 14 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ func newClientServerConnections() (cliConn *pipeConnection, svrConn *pipeConnect

type simpleHub struct {
Hub
receiveStreamArg string
receiveStreamChanValues []int
receiveStreamDone chan struct{}
receiveStreamArg string
receiveStreamDone chan struct{}
}

func (s *simpleHub) InvokeMe(arg1 string, arg2 int) string {
Expand All @@ -98,15 +97,14 @@ func (s *simpleHub) ReadStream(i int) chan string {
return ch
}

func (s *simpleHub) ReceiveStream(arg string, ch <-chan int) {
func (s *simpleHub) ReceiveStream(arg string, ch <-chan int) int {
s.receiveStreamArg = arg
s.receiveStreamChanValues = make([]int, 0)
go func(ch <-chan int, done chan struct{}) {
for v := range ch {
s.receiveStreamChanValues = append(s.receiveStreamChanValues, v)
}
done <- struct{}{}
}(ch, s.receiveStreamDone)
receiveStreamChanValues := make([]int, 0)
for v := range ch {
receiveStreamChanValues = append(receiveStreamChanValues, v)
}
s.receiveStreamDone <- struct{}{}
return 100
}

func (s *simpleHub) Abort() {
Expand Down Expand Up @@ -339,14 +337,17 @@ var _ = Describe("Client", func() {

It("should push a stream to the server", func(done Done) {
ch := make(chan int, 1)
_ = client.PushStreams("ReceiveStream", "test", ch)
r := client.PushStreams("ReceiveStream", "test", ch)
go func(ch chan int) {
for i := 1; i < 5; i++ {
ch <- i
}
close(ch)
}(ch)
<-hub.receiveStreamDone
ir := <-r
Expect(ir.Error).To(BeNil())
Expect(ir.Value).To(Equal(float64(100)))
Expect(hub.receiveStreamArg).To(Equal("test"))
cancelClient()
close(done)
Expand All @@ -355,8 +356,8 @@ var _ = Describe("Client", func() {
It("should return an error when the connection fails", func(done Done) {
cliConn.fail.Store(errors.New("fail"))
ch := make(chan int, 1)
err := <-client.PushStreams("ReceiveStream", "test", ch)
Expect(err).To(HaveOccurred())
ir := <-client.PushStreams("ReceiveStream", "test", ch)
Expect(ir.Error).To(HaveOccurred())
cancelClient()
close(done)
}, 1.0)
Expand Down
9 changes: 6 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@ go 1.16
require (
github.com/cenkalti/backoff/v4 v4.1.2
github.com/dave/jennifer v1.4.1
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/go-kit/log v0.2.0
github.com/google/uuid v1.3.0
github.com/klauspost/compress v1.13.6 // indirect
github.com/onsi/ginkgo v1.12.1
github.com/onsi/gomega v1.11.0
github.com/stretchr/testify v1.7.0
github.com/teivah/onecontext v1.3.0
github.com/vmihailenco/msgpack/v5 v5.3.5
nhooyr.io/websocket v1.8.7
)

require (
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/klauspost/compress v1.13.6 // indirect
golang.org/x/net v0.0.0-20211111160137-58aab5ef257a // indirect
golang.org/x/sys v0.0.0-20211112164355-7580c6e521dc // indirect
golang.org/x/text v0.3.7 // indirect
nhooyr.io/websocket v1.8.7
)
6 changes: 3 additions & 3 deletions logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package signalr
import (
"context"
"encoding/json"
"io/ioutil"
"io"
"os"
"sync"
"testing"
Expand All @@ -29,14 +29,14 @@ func testLoggerOption() func(Party) error {
func testLogger() StructuredLogger {
if tLog == nil {
lConf = loggerConfig{Enabled: false, Debug: false}
b, err := ioutil.ReadFile("testLogConf.json")
b, err := os.ReadFile("testLogConf.json")
if err == nil {
err = json.Unmarshal(b, &lConf)
if err != nil {
lConf = loggerConfig{Enabled: false, Debug: false}
}
}
writer := ioutil.Discard
writer := io.Discard
if lConf.Enabled {
writer = os.Stderr
}
Expand Down
31 changes: 13 additions & 18 deletions loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ func (l *loop) PullStream(method, id string, arguments ...interface{}) <-chan In
return ch
}

func (l *loop) PushStreams(method, id string, arguments ...interface{}) (<-chan error, error) {
_, errChan := l.invokeClient.newInvocation(id)
func (l *loop) PushStreams(method, id string, arguments ...interface{}) (<-chan InvokeResult, error) {
resultCh, errCh := l.invokeClient.newInvocation(id)
irCh := newInvokeResultChan(l.party.context(), resultCh, errCh)
invokeArgs := make([]interface{}, 0)
reflectedChannels := make([]reflect.Value, 0)
streamIds := make([]string, 0)
Expand All @@ -162,15 +163,15 @@ func (l *loop) PushStreams(method, id string, arguments ...interface{}) (<-chan
}
}
// Tell the server we are streaming now
if err := l.hubConn.SendInvocationWithStreamIds(l.GetNewID(), method, invokeArgs, streamIds); err != nil {
if err := l.hubConn.SendInvocationWithStreamIds(id, method, invokeArgs, streamIds); err != nil {
l.invokeClient.deleteInvocation(id)
return nil, err
}
// Start streaming on all channels
for i, reflectedChannel := range reflectedChannels {
l.streamer.Start(streamIds[i], reflectedChannel)
}
return errChan, nil
return irCh, nil
}

// GetNewID returns a new, connection-unique id for invocations and streams
Expand All @@ -186,16 +187,10 @@ func (l *loop) handleInvocationMessage(invocation invocationMessage) {
// Unable to find the method
_ = l.info.Log(evt, "getMethod", "error", "missing method", "name", invocation.Target, react, "send completion with error")
_ = l.hubConn.Completion(invocation.InvocationID, nil, fmt.Sprintf("Unknown method %s", invocation.Target))
} else if in, clientStreaming, err := buildMethodArguments(method, invocation, l.streamClient, l.protocol); err != nil {
} else if in, err := buildMethodArguments(method, invocation, l.streamClient, l.protocol); err != nil {
// argument build failed
_ = l.info.Log(evt, "buildMethodArguments", "error", err, "name", invocation.Target, react, "send completion with error")
_ = l.hubConn.Completion(invocation.InvocationID, nil, err.Error())
} else if clientStreaming {
// let the receiving method run independently
go func() {
defer l.recoverInvocationPanic(invocation)
method.Call(in)
}()
} else {
// Stream invocation is only allowed when the method has only one return value
// We allow no channel return values, because a client can receive as stream with only one item
Expand Down Expand Up @@ -273,7 +268,7 @@ func (l *loop) handleCompletionMessage(message completionMessage) error {
} else if l.invokeClient.handlesInvocationID(message.InvocationID) {
err = l.invokeClient.receiveCompletionItem(message)
} else {
err = fmt.Errorf("unkown invocationID %v", message.InvocationID)
err = fmt.Errorf("unknown invocationID %v", message.InvocationID)
}
if err != nil {
_ = l.info.Log(evt, msgRecv, "error", err, msg, fmtMsg(message), react, "close connection")
Expand Down Expand Up @@ -333,9 +328,9 @@ func (l *loop) recoverInvocationPanic(invocation invocationMessage) {
}

func buildMethodArguments(method reflect.Value, invocation invocationMessage,
streamClient *streamClient, protocol hubProtocol) (arguments []reflect.Value, clientStreaming bool, err error) {
streamClient *streamClient, protocol hubProtocol) (arguments []reflect.Value, err error) {
if len(invocation.StreamIds)+len(invocation.Arguments) != method.Type().NumIn() {
return nil, false, fmt.Errorf("parameter mismatch calling method %v", invocation.Target)
return nil, fmt.Errorf("parameter mismatch calling method %v", invocation.Target)
}
arguments = make([]reflect.Value, method.Type().NumIn())
chanCount := 0
Expand All @@ -344,7 +339,7 @@ func buildMethodArguments(method reflect.Value, invocation invocationMessage,
// Is it a channel for client streaming?
if arg, clientStreaming, err := streamClient.buildChannelArgument(invocation, t, chanCount); err != nil {
// it is, but channel count in invocation and method mismatch
return nil, false, err
return nil, err
} else if clientStreaming {
// it is
chanCount++
Expand All @@ -353,15 +348,15 @@ func buildMethodArguments(method reflect.Value, invocation invocationMessage,
// it is not, so do the normal thing
arg := reflect.New(t)
if err := protocol.UnmarshalArgument(invocation.Arguments[i-chanCount], arg.Interface()); err != nil {
return arguments, chanCount > 0, err
return arguments, err
}
arguments[i] = arg.Elem()
}
}
if len(invocation.StreamIds) != chanCount {
return arguments, chanCount > 0, fmt.Errorf("to many StreamIds for channel parameters of method %v", invocation.Target)
return arguments, fmt.Errorf("to many StreamIds for channel parameters of method %v", invocation.Target)
}
return arguments, chanCount > 0, nil
return arguments, nil
}

func getMethod(target interface{}, name string) (reflect.Value, bool) {
Expand Down
3 changes: 1 addition & 2 deletions serversseconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"sync"
Expand Down Expand Up @@ -51,7 +50,7 @@ func (s *serverSSEConnection) consumeRequest(request *http.Request) int {
defer func() {
_ = request.Body.Close()
}()
body, err := ioutil.ReadAll(request.Body)
body, err := io.ReadAll(request.Body)
if err != nil {
return http.StatusBadRequest // 400
} else if _, err := s.postWriter.Write(body); err != nil {
Expand Down
Loading

0 comments on commit a16987d

Please sign in to comment.