diff --git a/exchanges/bybit/bybit_types.go b/exchanges/bybit/bybit_types.go index 16114557684..557eddaf1d0 100644 --- a/exchanges/bybit/bybit_types.go +++ b/exchanges/bybit/bybit_types.go @@ -8,6 +8,7 @@ import ( "github.com/gofrs/uuid" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/types" ) @@ -33,10 +34,11 @@ type Authenticate struct { // SubscriptionArgument represents a subscription arguments. type SubscriptionArgument struct { - auth bool `json:"-"` - RequestID string `json:"req_id"` - Operation string `json:"op"` - Arguments []string `json:"args"` + auth bool `json:"-"` + RequestID string `json:"req_id"` + Operation string `json:"op"` + Arguments []string `json:"args"` + associatedSubs subscription.List `json:"-"` // Used to store associated subscriptions } // Fee holds fee information diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 9d1b84268c1..9db3934fece 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -167,28 +167,30 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) ( if err != nil { return } - chans := []string{} - authChans := []string{} + var chans subscription.List + var authChans subscription.List for _, s := range subs { if s.Authenticated { - authChans = append(authChans, s.QualifiedChannel) + authChans = append(authChans, s) } else { - chans = append(chans, s.QualifiedChannel) + chans = append(chans, s) } } for _, b := range common.Batch(chans, 10) { args = append(args, SubscriptionArgument{ - Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), - Arguments: b, + Operation: operation, + RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), + Arguments: b.QualifiedChannels(), + associatedSubs: b, }) } if len(authChans) != 0 { args = append(args, SubscriptionArgument{ - auth: true, - Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), - Arguments: authChans, + auth: true, + Operation: operation, + RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), + Arguments: authChans.QualifiedChannels(), + associatedSubs: authChans, }) } return @@ -225,6 +227,22 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su if !resp.Success { return fmt.Errorf("%s with request ID %s msg: %s", resp.Operation, resp.RequestID, resp.RetMsg) } + + var conn stream.Connection + if payloads[a].auth { + conn = by.Websocket.AuthConn + } else { + conn = by.Websocket.Conn + } + + if operation == "unsubscribe" { + err = by.Websocket.RemoveSubscriptions(conn, payloads[a].associatedSubs...) + } else { + err = by.Websocket.AddSubscriptions(conn, payloads[a].associatedSubs...) + } + if err != nil { + return err + } } return nil } diff --git a/exchanges/deribit/deribit_websocket.go b/exchanges/deribit/deribit_websocket.go index 21e7088bb7b..125c3d9b6ff 100644 --- a/exchanges/deribit/deribit_websocket.go +++ b/exchanges/deribit/deribit_websocket.go @@ -836,7 +836,14 @@ func (d *Deribit) handleSubscription(method string, subs subscription.List) erro err = common.AppendError(err, errors.New(s.String())) } } - return err + if err != nil { + return err + } + + if method == "unsubscribe" { + return d.Websocket.RemoveSubscriptions(d.Websocket.Conn, subs...) + } + return d.Websocket.AddSubscriptions(d.Websocket.Conn, subs...) } func getValidatedCurrencyCode(pair currency.Pair) string { diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 737c0eadc7f..59103912ebb 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -56,6 +56,8 @@ var ( errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") + errSubscriptionsNotAdded = errors.New("subscriptions not added") + errSubscriptionsNotRemoved = errors.New("subscriptions not removed") errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") errSameProxyAddress = errors.New("cannot set proxy address to the same address") errNoConnectFunc = errors.New("websocket connect func not set") @@ -372,6 +374,10 @@ func (w *Websocket) connect() error { if err := w.SubscribeToChannels(nil, subs); err != nil { return err } + + if w.subscriptions.Len() != len(subs) { + return fmt.Errorf("%s %w expecting %d subscribed", w.exchangeName, errSubscriptionsNotAdded, len(subs)) + } } return nil } @@ -455,6 +461,11 @@ func (w *Websocket) connect() error { break } + if len(subs) != 0 && w.connectionManager[i].Subscriptions.Len() != len(subs) { + multiConnectFatalError = fmt.Errorf("%v %w expecting %d subscribed %v", w.exchangeName, errSubscriptionsNotAdded, len(subs), subs) + break + } + if w.verbose { log.Debugf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] connected. [Subscribed: %d]", w.exchangeName, @@ -625,14 +636,7 @@ func (w *Websocket) FlushChannels() error { if err != nil { return err } - subs, unsubs := w.GetChannelDifference(nil, newSubs) - if err := w.UnsubscribeChannels(nil, unsubs); err != nil { - return err - } - if len(subs) == 0 { - return nil - } - return w.SubscribeToChannels(nil, subs) + return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs) } for x := range w.connectionManager { @@ -658,17 +662,9 @@ func (w *Websocket) FlushChannels() error { w.connectionManager[x].Connection = conn } - subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs) - - if len(unsubs) != 0 { - if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil { - return err - } - } - if len(subs) != 0 { - if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil { - return err - } + err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs) + if err != nil { + return err } // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. @@ -683,6 +679,31 @@ func (w *Websocket) FlushChannels() error { return nil } +// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels +// have been subscribed to or unsubscribed from. +func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error { + subs, unsubs := w.GetChannelDifference(c, incoming) + if len(unsubs) != 0 { + prevState := store.Len() + if err := w.UnsubscribeChannels(c, unsubs); err != nil { + return err + } + if diff := prevState - store.Len(); diff != len(unsubs) { + return fmt.Errorf("%v %w expected %d unsubscribed", w.exchangeName, errSubscriptionsNotRemoved, len(unsubs)) + } + } + if len(subs) != 0 { + prevState := store.Len() + if err := w.SubscribeToChannels(c, subs); err != nil { + return err + } + if diff := store.Len() - prevState; diff != len(subs) { + return fmt.Errorf("%v %w expected %d subscribed", w.exchangeName, errSubscriptionsNotAdded, len(subs)) + } + } + return nil +} + func (w *Websocket) setState(s uint32) { w.state.Store(s) } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index b6f3a762404..c64640738c1 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -63,31 +63,33 @@ type testSubKey struct { Mood string } -var defaultSetup = &WebsocketSetup{ - ExchangeConfig: &config.Exchange{ - Features: &config.FeaturesConfig{ - Enabled: config.FeaturesEnabledConfig{Websocket: true}, +func newDefaultSetup() *WebsocketSetup { + return &WebsocketSetup{ + ExchangeConfig: &config.Exchange{ + Features: &config.FeaturesConfig{ + Enabled: config.FeaturesEnabledConfig{Websocket: true}, + }, + API: config.APIConfig{ + AuthenticatedWebsocketSupport: true, + }, + WebsocketTrafficTimeout: time.Second * 5, + Name: "GTX", }, - API: config.APIConfig{ - AuthenticatedWebsocketSupport: true, + DefaultURL: "testDefaultURL", + RunningURL: "wss://testRunningURL", + Connector: func() error { return nil }, + Subscriber: func(subscription.List) error { return nil }, + Unsubscriber: func(subscription.List) error { return nil }, + GenerateSubscriptions: func() (subscription.List, error) { + return subscription.List{ + {Channel: "TestSub"}, + {Channel: "TestSub2", Key: "purple"}, + {Channel: "TestSub3", Key: testSubKey{"mauve"}}, + {Channel: "TestSub4", Key: 42}, + }, nil }, - WebsocketTrafficTimeout: time.Second * 5, - Name: "GTX", - }, - DefaultURL: "testDefaultURL", - RunningURL: "wss://testRunningURL", - Connector: func() error { return nil }, - Subscriber: func(subscription.List) error { return nil }, - Unsubscriber: func(subscription.List) error { return nil }, - GenerateSubscriptions: func() (subscription.List, error) { - return subscription.List{ - {Channel: "TestSub"}, - {Channel: "TestSub2", Key: "purple"}, - {Channel: "TestSub3", Key: testSubKey{"mauve"}}, - {Channel: "TestSub4", Key: 42}, - }, nil - }, - Features: &protocol.Features{Subscribe: true, Unsubscribe: true}, + Features: &protocol.Features{Subscribe: true, Unsubscribe: true}, + } } func TestSetup(t *testing.T) { @@ -186,13 +188,23 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") ws := NewWebsocket() - err = ws.Setup(defaultSetup) + err = ws.Setup(newDefaultSetup()) require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute ws.connector = connect - err = ws.Connect() - require.NoError(t, err, "Connect must not error") + require.ErrorIs(t, ws.Connect(), errSubscriptionsNotAdded) + require.NoError(t, ws.Shutdown()) + + ws.Subscriber = func(subs subscription.List) error { + for _, sub := range subs { + if err := ws.subscriptions.Add(sub); err != nil { + return err + } + } + return nil + } + require.NoError(t, ws.Connect(), "Connect must not error") checkToRoutineResult := func(t *testing.T) { t.Helper() @@ -279,6 +291,13 @@ func TestConnectionMessageErrors(t *testing.T) { return nil } err = ws.Connect() + require.ErrorIs(t, err, errSubscriptionsNotAdded) + + ws.connectionManager[0].Subscriptions = subscription.NewStore() + ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { + return ws.connectionManager[0].Subscriptions.Add(&subscription.Subscription{Channel: "test"}) + } + err = ws.Connect() require.NoError(t, err) err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test")) @@ -297,6 +316,7 @@ func TestWebsocket(t *testing.T) { assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") ws.setEnabled(true) + defaultSetup := newDefaultSetup() err = ws.Setup(defaultSetup) // Sets to enabled again require.NoError(t, err, "Setup may not error") @@ -340,8 +360,19 @@ func TestWebsocket(t *testing.T) { assert.NoError(t, ws.Shutdown()) ws.connector = func() error { return nil } - err = ws.Connect() - assert.NoError(t, err, "Connect should not error") + + require.ErrorIs(t, ws.Connect(), errSubscriptionsNotAdded) + require.NoError(t, ws.Shutdown()) + + ws.Subscriber = func(subs subscription.List) error { + for _, sub := range subs { + if err := ws.subscriptions.Add(sub); err != nil { + return err + } + } + return nil + } + assert.NoError(t, ws.Connect(), "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" @@ -407,7 +438,7 @@ func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscr func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() ws := NewWebsocket() - assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") + assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error") ws.Subscriber = currySimpleSub(ws) ws.Unsubscriber = currySimpleUnsub(ws) @@ -447,9 +478,9 @@ func TestSubscribeUnsubscribe(t *testing.T) { assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") multi := NewWebsocket() - set := *defaultSetup + set := newDefaultSetup() set.UseMultiConnectionManagement = true - assert.NoError(t, multi.Setup(&set)) + assert.NoError(t, multi.Setup(set)) amazingCandidate := &ConnectionSetup{ URL: "AMAZING", @@ -514,12 +545,12 @@ func TestResubscribe(t *testing.T) { t.Parallel() ws := NewWebsocket() - wackedOutSetup := *defaultSetup + wackedOutSetup := newDefaultSetup() wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 - err := ws.Setup(&wackedOutSetup) + err := ws.Setup(wackedOutSetup) assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error") - err = ws.Setup(defaultSetup) + err = ws.Setup(newDefaultSetup()) assert.NoError(t, err, "WS Setup should not error") ws.Subscriber = currySimpleSub(ws) @@ -1111,16 +1142,37 @@ func TestFlushChannels(t *testing.T) { // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)} w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil } - err = w.FlushChannels() - require.NoError(t, err, "Flush Channels must not error") + + require.ErrorIs(t, w.FlushChannels(), errSubscriptionsNotAdded, "FlushChannels should error correctly on no subscriptions added") + + w.Subscriber = func(subs subscription.List) error { + for _, sub := range subs { + if err := w.subscriptions.Add(sub); err != nil { + return err + } + } + return nil + } + + require.NoError(t, w.FlushChannels(), "Flush Channels must not error") w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub - err = w.FlushChannels() // No subs to sub - assert.NoError(t, err, "Flush Channels should not error") + + require.ErrorIs(t, w.FlushChannels(), errSubscriptionsNotRemoved) + + w.Unsubscriber = func(subs subscription.List) error { + for _, sub := range subs { + if err := w.subscriptions.Remove(sub); err != nil { + return err + } + } + return nil + } + assert.NoError(t, w.FlushChannels(), "Flush Channels should not error") w.GenerateSubs = newgen.generateSubs subs, err := w.GenerateSubs() @@ -1156,21 +1208,24 @@ func TestFlushChannels(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() + w.subscriptions = subscription.NewStore() + amazingCandidate := &ConnectionSetup{ URL: "ws" + mock.URL[len("http"):] + "/ws", Connector: func(ctx context.Context, conn Connection) error { return conn.DialContext(ctx, websocket.DefaultDialer, nil) }, GenerateSubscriptions: newgen.generateSubs, - Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { - return currySimpleSubConn(w)(ctx, c, s) - }, - Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { - return currySimpleUnsubConn(w)(ctx, c, s) - }, - Handler: func(context.Context, []byte) error { return nil }, + Subscriber: func(context.Context, Connection, subscription.List) error { return nil }, + Unsubscriber: func(context.Context, Connection, subscription.List) error { return nil }, + Handler: func(context.Context, []byte) error { return nil }, } require.NoError(t, w.SetupNewConnection(amazingCandidate)) + require.ErrorIs(t, w.FlushChannels(), errSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store") + + w.connectionManager[0].Setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleSubConn(w)(ctx, c, s) + } require.NoError(t, w.FlushChannels(), "FlushChannels must not error") // Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines. @@ -1181,6 +1236,11 @@ func TestFlushChannels(t *testing.T) { // of the connection from management. w.features.Subscribe = true w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } + require.ErrorIs(t, w.FlushChannels(), errSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store") + + w.connectionManager[0].Setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleUnsubConn(w)(ctx, c, s) + } require.NoError(t, w.FlushChannels(), "FlushChannels must not error") } @@ -1224,7 +1284,7 @@ func TestSetupNewConnection(t *testing.T) { web := NewWebsocket() - err = web.Setup(defaultSetup) + err = web.Setup(newDefaultSetup()) assert.NoError(t, err, "Setup should not error") err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) @@ -1235,9 +1295,9 @@ func TestSetupNewConnection(t *testing.T) { // Test connection candidates for multi connection tracking. multi := NewWebsocket() - set := *defaultSetup + set := newDefaultSetup() set.UseMultiConnectionManagement = true - require.NoError(t, multi.Setup(&set)) + require.NoError(t, multi.Setup(set)) err = multi.SetupNewConnection(nil) require.ErrorIs(t, err, errExchangeConfigEmpty)