diff --git a/Makefile b/Makefile index def073a..3bf4365 100644 --- a/Makefile +++ b/Makefile @@ -173,8 +173,10 @@ integration-test-container-dev: build-image-dev start-deps-container-dev test-db .PHONY: mocks mocks: - $(MOCKGENERATE) -source=interfaces/client.go -destination=mocks/firebase/client.go + $(MOCKGENERATE) -source=interfaces/client.go -destination=mocks/interfaces/client.go $(MOCKGENERATE) -source=interfaces/apns.go -destination=mocks/interfaces/apns.go $(MOCKGENERATE) -source=interfaces/statsd.go -destination=mocks/interfaces/statsd.go + $(MOCKGENERATE) -source=interfaces/stats_reporter.go -destination=mocks/interfaces/stats_reporter.go $(MOCKGENERATE) -source=interfaces/feedback_reporter.go -destination=mocks/interfaces/feedback_reporter.go $(MOCKGENERATE) -source=interfaces/message_handler.go -destination=mocks/interfaces/message_handler.go + $(MOCKGENERATE) -source=interfaces/rate_limiter.go -destination=mocks/interfaces/rate_limiter.go diff --git a/e2e/fcm_e2e_test.go b/e2e/fcm_e2e_test.go index c7cf7bf..0440892 100644 --- a/e2e/fcm_e2e_test.go +++ b/e2e/fcm_e2e_test.go @@ -15,9 +15,8 @@ import ( "github.com/stretchr/testify/suite" "github.com/topfreegames/pusher/config" "github.com/topfreegames/pusher/extensions" - "github.com/topfreegames/pusher/extensions/handler" + "github.com/topfreegames/pusher/extensions/firebase" "github.com/topfreegames/pusher/interfaces" - firebaseMock "github.com/topfreegames/pusher/mocks/firebase" mocks "github.com/topfreegames/pusher/mocks/interfaces" "github.com/topfreegames/pusher/pusher" "go.uber.org/mock/gomock" @@ -46,7 +45,7 @@ func (s *FcmE2ETestSuite) SetupSuite() { s.vConfig = v } -func (s *FcmE2ETestSuite) setupFcmPusher(appName string) (*firebaseMock.MockPushClient, *mocks.MockStatsDClient) { +func (s *FcmE2ETestSuite) setupFcmPusher(appName string) (*mocks.MockPushClient, *mocks.MockStatsDClient) { ctrl := gomock.NewController(s.T()) statsdClientMock := mocks.NewMockStatsDClient(ctrl) @@ -74,14 +73,15 @@ func (s *FcmE2ETestSuite) setupFcmPusher(appName string) (*firebaseMock.MockPush limit := s.vConfig.GetInt("gcm.rateLimit.rpm") rateLimiter := extensions.NewRateLimiter(limit, s.vConfig, []interfaces.StatsReporter{statsReport}, logger) - pushClient := firebaseMock.NewMockPushClient(ctrl) + pushClient := mocks.NewMockPushClient(ctrl) gcmPusher.MessageHandler = map[string]interfaces.MessageHandler{ - appName: handler.NewMessageHandler( + appName: firebase.NewMessageHandler( appName, pushClient, []interfaces.FeedbackReporter{}, []interfaces.StatsReporter{statsReport}, rateLimiter, + nil, logger, s.config.GCM.ConcurrentWorkers, ), diff --git a/extensions/apns_message_handler.go b/extensions/apns/apns_message_handler.go similarity index 72% rename from extensions/apns_message_handler.go rename to extensions/apns/apns_message_handler.go index c90d714..856fb16 100644 --- a/extensions/apns_message_handler.go +++ b/extensions/apns/apns_message_handler.go @@ -20,13 +20,14 @@ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package extensions +package apns import ( "context" "encoding/json" "errors" "fmt" + "github.com/topfreegames/pusher/extensions" "os" "sync" "time" @@ -40,8 +41,6 @@ import ( "github.com/topfreegames/pusher/structs" ) -var apnsResMutex sync.Mutex - // pusherAPNSKafkaMessage is the notification format received in Kafka messages. type pusherAPNSKafkaMessage struct { ApnsID string @@ -72,10 +71,8 @@ type APNSMessageHandler struct { sentMessages int64 ignoredMessages int64 successesReceived int64 - requestsHeap *TimeoutHeap CacheCleaningInterval int IsProduction bool - consumptionManager interfaces.ConsumptionManager retryInterval time.Duration maxRetryAttempts uint rateLimiter interfaces.RateLimiter @@ -93,31 +90,22 @@ func NewAPNSMessageHandler( statsReporters []interfaces.StatsReporter, feedbackReporters []interfaces.FeedbackReporter, pushQueue interfaces.APNSPushQueue, - consumptionManager interfaces.ConsumptionManager, rateLimiter interfaces.RateLimiter, ) (*APNSMessageHandler, error) { a := &APNSMessageHandler{ - authKeyPath: authKeyPath, - keyID: keyID, - teamID: teamID, - ApnsTopic: topic, - appName: appName, - Config: config, - failuresReceived: 0, - feedbackReporters: feedbackReporters, - IsProduction: isProduction, - Logger: logger, - pendingMessagesWG: pendingMessagesWG, - ignoredMessages: 0, - inFlightNotificationsMapLock: &sync.Mutex{}, - responsesReceived: 0, - sentMessages: 0, - StatsReporters: statsReporters, - successesReceived: 0, - requestsHeap: NewTimeoutHeap(config), - PushQueue: pushQueue, - consumptionManager: consumptionManager, - rateLimiter: rateLimiter, + authKeyPath: authKeyPath, + keyID: keyID, + teamID: teamID, + ApnsTopic: topic, + appName: appName, + Config: config, + feedbackReporters: feedbackReporters, + IsProduction: isProduction, + Logger: logger, + pendingMessagesWG: pendingMessagesWG, + StatsReporters: statsReporters, + PushQueue: pushQueue, + rateLimiter: rateLimiter, } if a.Logger != nil { @@ -196,35 +184,26 @@ func (a *APNSMessageHandler) HandleMessages(ctx context.Context, message interfa if err != nil { l.WithError(err).Error("error parsing kafka message") a.waitGroupDone() - apnsResMutex.Lock() - a.ignoredMessages++ - apnsResMutex.Unlock() return } l = l.WithField("notification", parsedNotification) + + allowed := a.rateLimiter.Allow(ctx, parsedNotification.DeviceToken, a.appName, "apns") + if !allowed { + extensions.StatsReporterNotificationRateLimitReached(a.StatsReporters, a.appName, "apns") + a.waitGroupDone() + return + } + n, err := a.buildAndValidateNotification(parsedNotification) if err != nil { l.WithError(err).Error("notification is invalid") a.waitGroupDone() - apnsResMutex.Lock() - a.ignoredMessages++ - apnsResMutex.Unlock() return - } - allowed := a.rateLimiter.Allow(ctx, parsedNotification.DeviceToken, a.appName, "apns") - if !allowed { - statsReporterNotificationRateLimitReached(a.StatsReporters, a.appName, "apns") - l.WithField("message", message).Warn("rate limit reached") - return } - a.sendNotification(n) - statsReporterHandleNotificationSent(a.StatsReporters, a.appName, "apns") - - apnsResMutex.Lock() - a.sentMessages++ - apnsResMutex.Unlock() + extensions.StatsReporterHandleNotificationSent(a.StatsReporters, a.appName, "apns") } func (a *APNSMessageHandler) parseKafkaMessage(message interfaces.KafkaMessage) (*pusherAPNSKafkaMessage, error) { @@ -240,8 +219,8 @@ func (a *APNSMessageHandler) parseKafkaMessage(message interfaces.KafkaMessage) } notification.Metadata["game"] = a.appName notification.Metadata["deviceToken"] = notification.DeviceToken - hostname, err := os.Hostname() + hostname, err := os.Hostname() if err != nil { a.Logger.WithError(err).Error("error retrieving hostname") } else { @@ -253,7 +232,7 @@ func (a *APNSMessageHandler) parseKafkaMessage(message interfaces.KafkaMessage) } func (a *APNSMessageHandler) buildAndValidateNotification(notification *pusherAPNSKafkaMessage) (*structs.ApnsNotification, error) { - if notification.PushExpiry > 0 && notification.PushExpiry < MakeTimestamp() { + if notification.PushExpiry > 0 && notification.PushExpiry < extensions.MakeTimestamp() { return nil, errors.New("push message has expired") } @@ -277,7 +256,7 @@ func (a *APNSMessageHandler) buildAndValidateNotification(notification *pusherAP func (a *APNSMessageHandler) sendNotification(notification *structs.ApnsNotification) { before := time.Now() - defer statsReporterReportSendNotificationLatency(a.StatsReporters, time.Since(before), a.appName, "apns", "client", "apns") + defer extensions.StatsReporterReportSendNotificationLatency(a.StatsReporters, time.Since(before), a.appName, "apns", "client", "apns") notification.SendAttempts += 1 a.PushQueue.Push(notification) @@ -311,11 +290,7 @@ func (a *APNSMessageHandler) handleAPNSResponse(responseWithMetadata *structs.Re } delete(responseWithMetadata.Metadata, "timestamp") - apnsResMutex.Lock() - a.responsesReceived++ - apnsResMutex.Unlock() - - parsedTopic := ParsedTopic{ + parsedTopic := extensions.ParsedTopic{ Game: a.appName, Platform: "apns", } @@ -323,94 +298,30 @@ func (a *APNSMessageHandler) handleAPNSResponse(responseWithMetadata *structs.Re a.waitGroupDone() if responseWithMetadata.Reason == "" { - sendFeedbackErr := sendToFeedbackReporters(a.feedbackReporters, responseWithMetadata, parsedTopic) - if sendFeedbackErr != nil { - l.WithError(sendFeedbackErr).Error("error sending feedback to reporter") - } - apnsResMutex.Lock() - a.successesReceived++ - apnsResMutex.Unlock() - statsReporterHandleNotificationSuccess(a.StatsReporters, a.appName, "apns") + extensions.StatsReporterHandleNotificationSuccess(a.StatsReporters, a.appName, "apns") return nil } - apnsResMutex.Lock() - a.failuresReceived++ - apnsResMutex.Unlock() - pErr := pusher_errors.NewPushError(a.mapErrorReason(responseWithMetadata.Reason), responseWithMetadata.Reason) responseWithMetadata.Err = pErr - statsReporterHandleNotificationFailure(a.StatsReporters, a.appName, "apns", pErr) - err := pErr + l.Info("notification failed") + extensions.StatsReporterHandleNotificationFailure(a.StatsReporters, a.appName, "apns", pErr) + switch responseWithMetadata.Reason { case apns2.ReasonBadDeviceToken, apns2.ReasonUnregistered, apns2.ReasonTopicDisallowed, apns2.ReasonDeviceTokenNotForTopic: // https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html - l.WithFields(log.Fields{ - "category": "TokenError", - log.ErrorKey: responseWithMetadata.Reason, - }).Debug("received an error") - if responseWithMetadata.Metadata != nil { - responseWithMetadata.Metadata["deleteToken"] = true + if responseWithMetadata.Metadata == nil { + responseWithMetadata.Metadata = map[string]interface{}{} } - case apns2.ReasonBadCertificate, apns2.ReasonBadCertificateEnvironment, apns2.ReasonForbidden: - l.WithFields(log.Fields{ - "category": "CertificateError", - log.ErrorKey: responseWithMetadata.Reason, - }).Debug("received an error") - case apns2.ReasonExpiredProviderToken, apns2.ReasonInvalidProviderToken, apns2.ReasonMissingProviderToken: - l.WithFields(log.Fields{ - "category": "ProviderTokenError", - log.ErrorKey: responseWithMetadata.Reason, - }).Debug("received an error") - case apns2.ReasonMissingTopic: - l.WithFields(log.Fields{ - "category": "TopicError", - log.ErrorKey: responseWithMetadata.Reason, - }).Debug("received an error") - case apns2.ReasonIdleTimeout, apns2.ReasonShutdown, apns2.ReasonInternalServerError, apns2.ReasonServiceUnavailable: - l.WithFields(log.Fields{ - "category": "AppleError", - log.ErrorKey: responseWithMetadata.Reason, - }).Debug("received an error") - default: - l.WithFields(log.Fields{ - "category": "DefaultError", - log.ErrorKey: responseWithMetadata.Reason, - }).Debug("received an error") + responseWithMetadata.Metadata["deleteToken"] = true } - sendFeedbackErr := sendToFeedbackReporters(a.feedbackReporters, responseWithMetadata, parsedTopic) + + sendFeedbackErr := extensions.SendToFeedbackReporters(a.feedbackReporters, responseWithMetadata, parsedTopic) if sendFeedbackErr != nil { l.WithError(sendFeedbackErr).Error("error sending feedback to reporter") } - return err -} -// LogStats from time to time. -func (a *APNSMessageHandler) LogStats() { - l := a.Logger.WithFields(log.Fields{ - "method": "apnsMessageHandler.logStats", - "interval(ns)": a.LogStatsInterval, - }) - - ticker := time.NewTicker(a.LogStatsInterval) - for range ticker.C { - apnsResMutex.Lock() - if a.sentMessages > 0 || a.responsesReceived > 0 || a.ignoredMessages > 0 || a.successesReceived > 0 || a.failuresReceived > 0 { - l.WithFields(log.Fields{ - "sentMessages": a.sentMessages, - "ignoredMessages": a.ignoredMessages, - "responsesReceived": a.responsesReceived, - "successesReceived": a.successesReceived, - "failuresReceived": a.failuresReceived, - }).Info("flushing stats") - a.sentMessages = 0 - a.responsesReceived = 0 - a.ignoredMessages = 0 - a.successesReceived = 0 - a.failuresReceived = 0 - } - apnsResMutex.Unlock() - } + return nil } func (a *APNSMessageHandler) mapErrorReason(reason string) string { diff --git a/extensions/apns/apns_message_handler_test.go b/extensions/apns/apns_message_handler_test.go new file mode 100644 index 0000000..8853853 --- /dev/null +++ b/extensions/apns/apns_message_handler_test.go @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2016 TFG Co + * Author: TFG Co + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package apns + +import ( + "context" + "encoding/json" + "fmt" + uuid "github.com/satori/go.uuid" + "github.com/sideshow/apns2" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/topfreegames/pusher/errors" + mock_interfaces "github.com/topfreegames/pusher/mocks/interfaces" + "go.uber.org/mock/gomock" + "sync" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/topfreegames/pusher/interfaces" + "github.com/topfreegames/pusher/structs" + "github.com/topfreegames/pusher/util" +) + +type ApnsMessageHandlerTestSuite struct { + suite.Suite + + config *viper.Viper + authKeyPath string + keyID string + teamID string + topic string + appName string + isProduction bool + logger *logrus.Logger + + mockApnsPushQueue *mock_interfaces.MockAPNSPushQueue + mockStatsReporter *mock_interfaces.MockStatsReporter + mockFeedbackReporter *mock_interfaces.MockFeedbackReporter + mockRateLimiter *mock_interfaces.MockRateLimiter + waitGroup *sync.WaitGroup + handler *APNSMessageHandler +} + +func TestApnsMessageHandlerTestSuite(t *testing.T) { + suite.Run(t, new(ApnsMessageHandlerTestSuite)) +} + +func (s *ApnsMessageHandlerTestSuite) SetupSuite() { + cfg, err := util.NewViperWithConfigFile("../../config/test.yaml") + require.NoError(s.T(), err) + + s.config = cfg + s.authKeyPath = "../tls/authkey.p8" + s.keyID = "ABC123DEFG" + s.teamID = "DEF123GHIJ" + s.topic = "com.game.test" + s.appName = "game" + s.isProduction = false + s.logger, _ = test.NewNullLogger() + s.logger.Level = logrus.DebugLevel +} + +func (s *ApnsMessageHandlerTestSuite) SetupTest() { + ctrl := gomock.NewController(s.T()) + + mockFeedbackReporter := mock_interfaces.NewMockFeedbackReporter(ctrl) + mockStatsReporter := mock_interfaces.NewMockStatsReporter(ctrl) + statsClients := []interfaces.StatsReporter{mockStatsReporter} + feedbackClients := []interfaces.FeedbackReporter{mockFeedbackReporter} + + mockPushQueue := mock_interfaces.NewMockAPNSPushQueue(ctrl) + mockRateLimiter := mock_interfaces.NewMockRateLimiter(ctrl) + wg := &sync.WaitGroup{} + + handler, err := NewAPNSMessageHandler( + s.authKeyPath, + s.keyID, + s.teamID, + s.topic, + s.appName, + s.isProduction, + s.config, + s.logger, + wg, + statsClients, + feedbackClients, + mockPushQueue, + mockRateLimiter, + ) + require.NoError(s.T(), err) + + s.handler = handler + s.mockApnsPushQueue = mockPushQueue + s.mockStatsReporter = mockStatsReporter + s.mockFeedbackReporter = mockFeedbackReporter + s.waitGroup = wg + s.mockRateLimiter = mockRateLimiter +} + +func (s *ApnsMessageHandlerTestSuite) SetupSubTest() { + s.SetupTest() +} + +func (s *ApnsMessageHandlerTestSuite) TestHandleMessage() { + s.Run("should fail if invalid kafka message format", func() { + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(`not json`), + } + + s.waitGroup.Add(1) + s.handler.HandleMessages(context.Background(), msg) + + waitWG(s.T(), s.waitGroup) + }) + + s.Run("should fail if notification expired", func() { + expiration := time.Now().Add(-1 * time.Hour).Unix() + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(fmt.Sprintf(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "expiration": 0 }, "push_expiry": %d }`, expiration)), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), gomock.Any(), s.appName, "apns"). + Return(true) + + s.waitGroup.Add(1) + s.handler.HandleMessages(context.Background(), msg) + + waitWG(s.T(), s.waitGroup) + }) + + s.Run("should fail if rate limit reached", func() { + expiration := time.Now().Add(1 * time.Hour).UnixNano() + token := uuid.NewV4().String() + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(fmt.Sprintf(`{"DeviceToken": "%s", "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "expiration": 0 }, "push_expiry": %d }`, + token, + expiration, + )), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), token, s.appName, "apns"). + Return(false) + + s.mockStatsReporter.EXPECT(). + NotificationRateLimitReached(s.appName, "apns"). + Return() + + s.waitGroup.Add(1) + s.handler.HandleMessages(context.Background(), msg) + waitWG(s.T(), s.waitGroup) + }) + + s.Run("should succeed", func() { + expiration := time.Now().Add(1 * time.Hour).UnixNano() + token := uuid.NewV4().String() + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(fmt.Sprintf(`{"DeviceToken": "%s", "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "expiration": 0 }, "push_expiry": %d }`, + token, + expiration), + ), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), token, s.appName, "apns"). + Return(true) + + s.mockApnsPushQueue.EXPECT(). + Push(gomock.Any()). + Do(func(n *structs.ApnsNotification) { + assert.Equal(s.T(), s.topic, n.Topic) + assert.Equal(s.T(), token, n.DeviceToken) + assert.Equal(s.T(), 1, n.SendAttempts) + }) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.appName, "apns", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.appName, "apns"). + Return() + + s.handler.HandleMessages(context.Background(), msg) + }) + + s.Run("should succeed and send metadata on notification", func() { + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "some": "data" }}`), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), gomock.Any(), s.appName, "apns"). + Return(true) + + s.mockApnsPushQueue.EXPECT(). + Push(gomock.Any()). + Do(func(n *structs.ApnsNotification) { + bytes, ok := n.Notification.Payload.([]byte) + assert.True(s.T(), ok) + + var payload map[string]interface{} + err := json.Unmarshal(bytes, &payload) + assert.NoError(s.T(), err) + + metadata, ok := payload["M"].(map[string]interface{}) + assert.True(s.T(), ok) + assert.Equal(s.T(), "data", metadata["some"]) + }) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.appName, "apns", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.appName, "apns"). + Return() + + s.handler.HandleMessages(context.Background(), msg) + }) + + s.Run("should succeed and merge metadata on notification", func() { + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" }, "M": { "previous": "value" }}, "Metadata": { "some": "data" } }`), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), gomock.Any(), s.appName, "apns"). + Return(true) + + s.mockApnsPushQueue.EXPECT(). + Push(gomock.Any()). + Do(func(n *structs.ApnsNotification) { + bytes, ok := n.Notification.Payload.([]byte) + assert.True(s.T(), ok) + + var payload map[string]interface{} + err := json.Unmarshal(bytes, &payload) + assert.NoError(s.T(), err) + + metadata, ok := payload["M"].(map[string]interface{}) + assert.True(s.T(), ok) + assert.Equal(s.T(), "value", metadata["previous"]) + assert.Equal(s.T(), "data", metadata["some"]) + }) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.appName, "apns", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.appName, "apns"). + Return() + + s.handler.HandleMessages(context.Background(), msg) + }) + + s.Run("should succeed and handle nested metadata on notification", func() { + msg := interfaces.KafkaMessage{ + Topic: "push-game_apns", + Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" }}, "Metadata": { "nested": { "some": "data" }}}`), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), gomock.Any(), s.appName, "apns"). + Return(true) + + s.mockApnsPushQueue.EXPECT(). + Push(gomock.Any()). + Do(func(n *structs.ApnsNotification) { + bytes, ok := n.Notification.Payload.([]byte) + assert.True(s.T(), ok) + + var payload map[string]interface{} + err := json.Unmarshal(bytes, &payload) + assert.NoError(s.T(), err) + + metadata, ok := payload["M"].(map[string]interface{}) + assert.True(s.T(), ok) + nested, ok := metadata["nested"].(map[string]interface{}) + assert.True(s.T(), ok) + assert.Equal(s.T(), "data", nested["some"]) + }) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.appName, "apns", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.appName, "apns"). + Return() + + s.handler.HandleMessages(context.Background(), msg) + }) +} + +func (s *ApnsMessageHandlerTestSuite) TestResponseHandle() { + s.Run("should send ack metric if response has no error", func() { + apnsID := uuid.NewV4().String() + res := &structs.ResponseWithMetadata{ + StatusCode: 200, + ApnsID: apnsID, + Notification: &structs.ApnsNotification{ + Notification: apns2.Notification{ + ApnsID: apnsID, + }, + }, + } + + s.mockStatsReporter.EXPECT(). + HandleNotificationSuccess(s.appName, "apns") + + s.waitGroup.Add(1) + + err := s.handler.handleAPNSResponse(res) + s.NoError(err) + waitWG(s.T(), s.waitGroup) + }) + + for _, r := range errorReasons { + if r == apns2.ReasonTooManyRequests { + continue + } + s.Run(fmt.Sprintf("should send feedback and failure metric if response is %s", r), func() { + apnsID := uuid.NewV4().String() + res := &structs.ResponseWithMetadata{ + StatusCode: 400, + ApnsID: apnsID, + Reason: r, + Notification: &structs.ApnsNotification{ + Notification: apns2.Notification{ + ApnsID: apnsID, + }, + }, + } + + s.mockStatsReporter.EXPECT(). + HandleNotificationFailure(s.appName, "apns", errors.NewPushError(s.handler.mapErrorReason(res.Reason), res.Reason)). + Return() + + s.mockFeedbackReporter.EXPECT(). + SendFeedback(s.appName, "apns", gomock.Any()). + Do(func(game, platform string, feedback []byte) { + var actualRes structs.ResponseWithMetadata + err := json.Unmarshal(feedback, &actualRes) + s.NoError(err) + s.Equal(apnsID, res.ApnsID) + + if r == apns2.ReasonBadDeviceToken { + deleteToken, ok := actualRes.Metadata["deleteToken"].(bool) + s.True(ok) + s.True(deleteToken) + } + }) + + s.waitGroup.Add(1) + + err := s.handler.handleAPNSResponse(res) + s.NoError(err) + waitWG(s.T(), s.waitGroup) + + }) + } + + s.Run("should retry if TooManyRequests", func() { + apnsID := uuid.NewV4().String() + res := &structs.ResponseWithMetadata{ + StatusCode: 429, + ApnsID: apnsID, + Reason: apns2.ReasonTooManyRequests, + Notification: &structs.ApnsNotification{ + Notification: apns2.Notification{ + ApnsID: apnsID, + }, + }, + } + + s.mockApnsPushQueue.EXPECT(). + Push(res.Notification). + Return() + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.appName, "apns", gomock.Any()). + Return() + + err := s.handler.handleAPNSResponse(res) + s.NoError(err) + }) +} + +func waitWG(t *testing.T, wg *sync.WaitGroup) { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + timeout := time.After(10 * time.Millisecond) + select { + case <-done: + case <-timeout: + t.Fatal("timed out waiting for waitgroup") + } +} + +var errorReasons = []string{ + apns2.ReasonPayloadEmpty, + apns2.ReasonPayloadTooLarge, + apns2.ReasonMissingDeviceToken, + apns2.ReasonBadDeviceToken, + apns2.ReasonTooManyRequests, + apns2.ReasonBadMessageID, + apns2.ReasonBadExpirationDate, + apns2.ReasonBadPriority, + apns2.ReasonBadTopic, + apns2.ReasonBadCertificate, + apns2.ReasonBadCertificateEnvironment, + apns2.ReasonForbidden, + apns2.ReasonMissingTopic, + apns2.ReasonTopicDisallowed, + apns2.ReasonUnregistered, + apns2.ReasonDeviceTokenNotForTopic, + apns2.ReasonDuplicateHeaders, + apns2.ReasonBadPath, + apns2.ReasonMethodNotAllowed, + apns2.ReasonIdleTimeout, + apns2.ReasonShutdown, + apns2.ReasonInternalServerError, + apns2.ReasonServiceUnavailable, + apns2.ReasonExpiredProviderToken, + apns2.ReasonInvalidProviderToken, + apns2.ReasonMissingProviderToken, +} diff --git a/extensions/apns_push_queue.go b/extensions/apns/apns_push_queue.go similarity index 99% rename from extensions/apns_push_queue.go rename to extensions/apns/apns_push_queue.go index 87b8682..bd32c87 100644 --- a/extensions/apns_push_queue.go +++ b/extensions/apns/apns_push_queue.go @@ -20,7 +20,7 @@ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package extensions +package apns import ( "github.com/sideshow/apns2" diff --git a/extensions/apns_push_queue_test.go b/extensions/apns/apns_push_queue_test.go similarity index 99% rename from extensions/apns_push_queue_test.go rename to extensions/apns/apns_push_queue_test.go index 6ce99d0..af57873 100644 --- a/extensions/apns_push_queue_test.go +++ b/extensions/apns/apns_push_queue_test.go @@ -20,7 +20,7 @@ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package extensions +package apns import ( "os" diff --git a/extensions/apns_message_handler_test.go b/extensions/apns_message_handler_test.go deleted file mode 100644 index 95766d7..0000000 --- a/extensions/apns_message_handler_test.go +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Copyright (c) 2016 TFG Co - * Author: TFG Co - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package extensions - -import ( - "context" - "encoding/json" - "fmt" - "os" - "time" - - uuid "github.com/satori/go.uuid" - "github.com/sideshow/apns2" - mock_interfaces "github.com/topfreegames/pusher/mocks/interfaces" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/test" - "github.com/topfreegames/pusher/interfaces" - "github.com/topfreegames/pusher/mocks" - "github.com/topfreegames/pusher/structs" - "github.com/topfreegames/pusher/util" -) - -var _ = Describe("APNS Message Handler", func() { - var db interfaces.DB - var feedbackClients []interfaces.FeedbackReporter - var handler *APNSMessageHandler - var mockKafkaProducerClient *mocks.KafkaProducerClientMock - var mockPushQueue *mocks.APNSPushQueueMock - var mockStatsDClient *mocks.StatsDClientMock - var statsClients []interfaces.StatsReporter - mockConsumptionManager := mock_interfaces.NewMockConsumptionManager() - mockRateLimiter := mocks.NewRateLimiterMock() - ctx := context.Background() - - configFile := os.Getenv("CONFIG_FILE") - if configFile == "" { - configFile = "../config/test.yaml" - } - config, _ := util.NewViperWithConfigFile(configFile) - authKeyPath := "../tls/authkey.p8" - keyID := "ABC123DEFG" - teamID := "DEF123GHIJ" - topic := "com.game.test" - appName := "game" - isProduction := false - logger, hook := test.NewNullLogger() - logger.Level = logrus.DebugLevel - - Describe("[Unit]", func() { - BeforeEach(func() { - mockStatsDClient = mocks.NewStatsDClientMock() - mockKafkaProducerClient = mocks.NewKafkaProducerClientMock() - mockKafkaProducerClient.StartConsumingMessagesInProduceChannel() - c, err := NewStatsD(config, logger, mockStatsDClient) - Expect(err).NotTo(HaveOccurred()) - - kc, err := NewKafkaProducer(config, logger, mockKafkaProducerClient) - Expect(err).NotTo(HaveOccurred()) - - statsClients = []interfaces.StatsReporter{c} - feedbackClients = []interfaces.FeedbackReporter{kc} - - db = mocks.NewPGMock(0, 1) - - mockPushQueue = mocks.NewAPNSPushQueueMock() - handler, err = NewAPNSMessageHandler( - authKeyPath, - keyID, - teamID, - topic, - appName, - isProduction, - config, - logger, - nil, - statsClients, - feedbackClients, - mockPushQueue, - mockConsumptionManager, - mockRateLimiter, - ) - Expect(err).NotTo(HaveOccurred()) - db.(*mocks.PGMock).RowsReturned = 0 - - hook.Reset() - }) - - Describe("Creating new handler", func() { - It("should return configured handler", func() { - Expect(handler).NotTo(BeNil()) - Expect(handler.Config).NotTo(BeNil()) - Expect(handler.IsProduction).To(Equal(isProduction)) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(0))) - Expect(handler.sentMessages).To(Equal(int64(0))) - apnsResMutex.Unlock() - }) - }) - - Describe("Handle APNS response", func() { - It("if response has nil error", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 200, - ApnsID: apnsID, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.successesReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ReasonUnregistered", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonUnregistered, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrBadDeviceToken", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonBadDeviceToken, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrBadCertificate", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonBadCertificate, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrBadCertificateEnvironment", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonBadCertificateEnvironment, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrForbidden", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonForbidden, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrMissingTopic", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonMissingTopic, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrTopicDisallowed", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonTopicDisallowed, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrDeviceTokenNotForTopic", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonDeviceTokenNotForTopic, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrIdleTimeout", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonIdleTimeout, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrShutdown", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonShutdown, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrInternalServerError", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonInternalServerError, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has error push.ErrServiceUnavailable", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonServiceUnavailable, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("if response has untracked error", func() { - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: apnsID, - Reason: apns2.ReasonMethodNotAllowed, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: apnsID, - }, - }, - } - handler.handleAPNSResponse(res) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - Expect(handler.failuresReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - }) - - Describe("Send notification", func() { - It("should add message to push queue", func() { - m, err := handler.parseKafkaMessage(interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "some": "data" } }`), - }) - Expect(err).To(BeNil()) - - n, err := handler.buildAndValidateNotification(m) - - handler.sendNotification(n) - res := mockPushQueue.PushedNotification - Expect(res).NotTo(BeNil()) - }) - - It("should have metadata on message sent to push queue", func() { - m, err := handler.parseKafkaMessage(interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "some": "data" } }`), - }) - Expect(err).To(BeNil()) - - n, err := handler.buildAndValidateNotification(m) - Expect(err).To(BeNil()) - - handler.sendNotification(n) - - sentMessage := mockPushQueue.PushedNotification - bytes, ok := sentMessage.Notification.Payload.([]byte) - Expect(ok).To(BeTrue()) - var payload map[string]interface{} - err = json.Unmarshal(bytes, &payload) - Expect(err).To(BeNil()) - Expect(payload).NotTo(BeNil()) - Expect(payload["M"]).NotTo(BeNil()) - - metadata, ok := payload["M"].(map[string]interface{}) - Expect(ok).To(BeTrue()) - Expect(metadata).NotTo(BeNil()) - Expect(metadata["some"]).To(Equal("data")) - }) - - It("should merge metadata on message sent to push queue", func() { - m, err := handler.parseKafkaMessage(interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" }, "M": { "metadata": "received" } }, "Metadata": { "some": "data" } }`), - }) - Expect(err).To(BeNil()) - - n, err := handler.buildAndValidateNotification(m) - Expect(err).To(BeNil()) - - handler.sendNotification(n) - - sentMessage := mockPushQueue.PushedNotification - bytes, ok := sentMessage.Payload.([]byte) - Expect(ok).To(BeTrue()) - var payload map[string]interface{} - json.Unmarshal(bytes, &payload) - - Expect(payload).NotTo(BeNil()) - Expect(payload["M"]).NotTo(BeNil()) - - metadata, ok := payload["M"].(map[string]interface{}) - Expect(ok).To(BeTrue()) - Expect(metadata).NotTo(BeNil()) - Expect(metadata["metadata"]).To(Equal("received")) - Expect(metadata["some"]).To(Equal("data")) - }) - - It("should have nested metadata on message sent to push queue", func() { - m, err := handler.parseKafkaMessage(interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(`{ "Payload": { "aps" : { "alert" : "Hello HTTP/2" }, "M": { "metadata": "received" } }, "Metadata": { "nested": { "some": "data"} } }`), - }) - Expect(err).To(BeNil()) - - n, err := handler.buildAndValidateNotification(m) - Expect(err).To(BeNil()) - - handler.sendNotification(n) - - sentMessage := mockPushQueue.PushedNotification - bytes, ok := sentMessage.Payload.([]byte) - Expect(ok).To(BeTrue()) - var payload map[string]interface{} - json.Unmarshal(bytes, &payload) - - Expect(payload).NotTo(BeNil()) - Expect(payload["M"]).NotTo(BeNil()) - - metadata, ok := payload["M"].(map[string]interface{}) - Expect(ok).To(BeTrue()) - Expect(metadata).NotTo(BeNil()) - Expect(metadata["metadata"]).To(Equal("received")) - nestedMetadata, ok := metadata["nested"].(map[string]interface{}) - Expect(ok).To(BeTrue()) - Expect(nestedMetadata["some"]).To(Equal("data")) - }) - }) - - Describe("Log Stats", func() { - It("should log and zero stats", func() { - handler.sentMessages = 100 - handler.responsesReceived = 90 - handler.successesReceived = 60 - handler.failuresReceived = 30 - Expect(func() { go handler.LogStats() }).ShouldNot(Panic()) - time.Sleep(2 * handler.LogStatsInterval) - - apnsResMutex.Lock() - Eventually(func() int64 { return handler.sentMessages }).Should(Equal(int64(0))) - Eventually(func() int64 { return handler.responsesReceived }).Should(Equal(int64(0))) - Eventually(func() int64 { return handler.successesReceived }).Should(Equal(int64(0))) - Eventually(func() int64 { return handler.failuresReceived }).Should(Equal(int64(0))) - apnsResMutex.Unlock() - }) - }) - - Describe("Stats Reporter sent message", func() { - It("should call HandleNotificationSent upon message sent to queue", func() { - Expect(handler).NotTo(BeNil()) - Expect(handler.StatsReporters).To(Equal(statsClients)) - kafkaMessage := interfaces.KafkaMessage{ - Game: "game", - Topic: "push-game_apns", - Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`), - } - handler.HandleMessages(ctx, kafkaMessage) - handler.HandleMessages(ctx, kafkaMessage) - - Expect(mockStatsDClient.Counts["sent"]).To(Equal(int64(2))) - }) - - It("should call HandleNotificationSuccess upon message response received", func() { - Expect(handler).NotTo(BeNil()) - Expect(handler.StatsReporters).To(Equal(statsClients)) - - apnsID := uuid.NewV4().String() - res := &structs.ResponseWithMetadata{ - StatusCode: 200, - ApnsID: apnsID, - Notification: &structs.ApnsNotification{Notification: apns2.Notification{ApnsID: apnsID}}, - } - - handler.handleAPNSResponse(res) - handler.handleAPNSResponse(res) - Expect(mockStatsDClient.Counts["ack"]).To(Equal(int64(2))) - }) - - It("should call HandleNotificationFailure upon message response received", func() { - Expect(handler).NotTo(BeNil()) - Expect(handler.StatsReporters).To(Equal(statsClients)) - - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: uuid.NewV4().String(), - Reason: apns2.ReasonMissingDeviceToken, - Notification: &structs.ApnsNotification{}, - } - handler.handleAPNSResponse(res) - handler.handleAPNSResponse(res) - - Expect(mockStatsDClient.Counts["failed"]).To(Equal(int64(2))) - }) - }) - - Describe("Feedback Reporter sent message", func() { - BeforeEach(func() { - mockKafkaProducerClient = mocks.NewKafkaProducerClientMock() - kc, err := NewKafkaProducer(config, logger, mockKafkaProducerClient) - Expect(err).NotTo(HaveOccurred()) - - feedbackClients = []interfaces.FeedbackReporter{kc} - - db = mocks.NewPGMock(0, 1) - - handler, err = NewAPNSMessageHandler( - authKeyPath, - keyID, - teamID, - topic, - appName, - isProduction, - config, - logger, - nil, - statsClients, - feedbackClients, - mockPushQueue, - mockConsumptionManager, - mockRateLimiter, - ) - Expect(err).NotTo(HaveOccurred()) - }) - It("should include a timestamp in feedback root and the hostname in metadata", func() { - timestampNow := time.Now().Unix() - hostname, err := os.Hostname() - Expect(err).NotTo(HaveOccurred()) - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": timestampNow, - "hostname": hostname, - "game": "game", - "platform": "apns", - } - res := &structs.ResponseWithMetadata{ - StatusCode: 200, - ApnsID: "idTest1", - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest1", - }, - Metadata: metadata, - }, - } - - go handler.handleAPNSResponse(res) - - fromKafka := &structs.ResponseWithMetadata{} - msg := <-mockKafkaProducerClient.ProduceChannel() - json.Unmarshal(msg.Value, fromKafka) - Expect(fromKafka.Timestamp).To(Equal(timestampNow)) - Expect(fromKafka.Metadata["hostname"]).To(Equal(hostname)) - }) - - It("should send feedback if success and metadata is present", func() { - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "apns", - } - res := &structs.ResponseWithMetadata{ - StatusCode: 200, - ApnsID: "idTest1", - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest1", - }, - Metadata: metadata, - }, - } - go handler.handleAPNSResponse(res) - - fromKafka := &structs.ResponseWithMetadata{} - msg := <-mockKafkaProducerClient.ProduceChannel() - json.Unmarshal(msg.Value, fromKafka) - Expect(fromKafka.ApnsID).To(Equal(res.ApnsID)) - Expect(fromKafka.Metadata["some"]).To(Equal(metadata["some"])) - }) - - It("should send feedback if success and metadata is not present", func() { - res := &structs.ResponseWithMetadata{ - StatusCode: 200, - ApnsID: "idTest1", - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest1", - }, - }, - Metadata: map[string]interface{}{ - "timestamp": int64(0), - }, - } - - go handler.handleAPNSResponse(res) - - fromKafka := &structs.ResponseWithMetadata{} - msg := <-mockKafkaProducerClient.ProduceChannel() - json.Unmarshal(msg.Value, fromKafka) - Expect(fromKafka.ApnsID).To(Equal(res.ApnsID)) - Expect(fromKafka.Metadata["some"]).To(BeNil()) - }) - - It("should send feedback if error and metadata is present and token should be deleted", func() { - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "apns", - } - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: "idTest1", - Reason: apns2.ReasonBadDeviceToken, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest1", - }, - Metadata: metadata, - }, - } - go handler.handleAPNSResponse(res) - - fromKafka := &structs.ResponseWithMetadata{} - msg := <-mockKafkaProducerClient.ProduceChannel() - json.Unmarshal(msg.Value, fromKafka) - Expect(fromKafka.ApnsID).To(Equal(res.ApnsID)) - Expect(fromKafka.Metadata["some"]).To(Equal(metadata["some"])) - Expect(fromKafka.Metadata["deleteToken"]).To(BeTrue()) - Expect(string(msg.Value)).To(ContainSubstring("BadDeviceToken")) - }) - - It("should send feedback if error and metadata is present and token should not be deleted", func() { - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "apns", - } - - res := &structs.ResponseWithMetadata{ - StatusCode: 400, - ApnsID: "idTest1", - Reason: apns2.ReasonBadMessageID, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest1", - }, - Metadata: metadata, - }, - } - go handler.handleAPNSResponse(res) - - fromKafka := &structs.ResponseWithMetadata{} - msg := <-mockKafkaProducerClient.ProduceChannel() - json.Unmarshal(msg.Value, fromKafka) - Expect(fromKafka.ApnsID).To(Equal(res.ApnsID)) - Expect(fromKafka.Metadata["some"]).To(Equal(metadata["some"])) - Expect(fromKafka.Metadata["deleteToken"]).To(BeNil()) - Expect(string(msg.Value)).To(ContainSubstring("BadMessageId")) - }) - - It("should send feedback if error and metadata is not present", func() { - res := &structs.ResponseWithMetadata{ - DeviceToken: uuid.NewV4().String(), - StatusCode: 400, - ApnsID: "idTest1", - Reason: apns2.ReasonBadDeviceToken, - Notification: &structs.ApnsNotification{}, - } - go handler.handleAPNSResponse(res) - - fromKafka := &structs.ResponseWithMetadata{} - msg := <-mockKafkaProducerClient.ProduceChannel() - json.Unmarshal(msg.Value, fromKafka) - Expect(fromKafka.ApnsID).To(Equal(res.ApnsID)) - Expect(fromKafka.Metadata).To(BeNil()) - Expect(string(msg.Value)).To(ContainSubstring("BadDeviceToken")) - }) - - It("should not deadlock on handle retry for handle apns response", func() { - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "apns", - } - - res := &structs.ResponseWithMetadata{ - StatusCode: 429, - ApnsID: "idTest1", - Reason: apns2.ReasonTooManyRequests, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest1", - }, - }, - Metadata: metadata, - } - - res2 := &structs.ResponseWithMetadata{ - StatusCode: 429, - ApnsID: "idTest2", - Reason: apns2.ReasonTooManyRequests, - Notification: &structs.ApnsNotification{ - Notification: apns2.Notification{ - ApnsID: "idTest2", - }, - }, - Metadata: metadata, - } - go func() { - defer GinkgoRecover() - err := handler.handleAPNSResponse(res) - Expect(err).NotTo(HaveOccurred()) - }() - go func() { - defer GinkgoRecover() - err := handler.handleAPNSResponse(res2) - Expect(err).NotTo(HaveOccurred()) - }() - }) - }) - - Describe("Cleanup", func() { - It("should close PushQueue without error", func() { - err := handler.Cleanup() - Expect(err).NotTo(HaveOccurred()) - Expect(handler.PushQueue.(*mocks.APNSPushQueueMock).Closed).To(BeTrue()) - }) - }) - }) - Describe("[Integration]", func() { - BeforeEach(func() { - var err error - handler, err = NewAPNSMessageHandler( - authKeyPath, - keyID, - teamID, - topic, - appName, - isProduction, - config, - logger, - nil, - nil, - nil, - nil, - nil, - mockRateLimiter, - ) - Expect(err).NotTo(HaveOccurred()) - hook.Reset() - }) - - Describe("Send message", func() { - It("should add message to push queue and increment sentMessages", func() { - handler.HandleMessages(ctx, interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`), - }) - - apnsResMutex.Lock() - Expect(handler.ignoredMessages).To(Equal(int64(0))) - Eventually(handler.PushQueue.ResponseChannel(), 5*time.Second).Should(Receive()) - Expect(handler.sentMessages).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - - It("should be able to call HandleMessages concurrently with no errors", func() { - msg := interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(`{ "aps" : { "alert" : "Hello" } }`), - } - - go handler.HandleMessages(context.Background(), msg) - go handler.HandleMessages(context.Background(), msg) - go handler.HandleMessages(context.Background(), msg) - }) - }) - - Describe("PushExpiry", func() { - It("should not send message if PushExpiry is in the past", func() { - handler.HandleMessages(ctx, interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(fmt.Sprintf(`{ "aps" : { "alert" : "Hello HTTP/2" }, "push_expiry": %d }`, MakeTimestamp()-int64(100))), - }) - Eventually(handler.PushQueue.ResponseChannel(), 100*time.Millisecond).ShouldNot(Receive()) - apnsResMutex.Lock() - Expect(handler.sentMessages).To(Equal(int64(0))) - Expect(handler.ignoredMessages).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - It("should send message if PushExpiry is in the future", func() { - handler.HandleMessages(ctx, interfaces.KafkaMessage{ - Topic: "push-game_apns", - Value: []byte(fmt.Sprintf(`{ "aps" : { "alert" : "Hello HTTP/2" }, "push_expiry": %d}`, MakeTimestamp()+int64(100))), - }) - Eventually(handler.PushQueue.ResponseChannel(), 100*time.Millisecond).ShouldNot(Receive()) - - apnsResMutex.Lock() - Expect(handler.sentMessages).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - }) - - Describe("Handle Responses", func() { - It("should be called without panicking", func() { - Expect(func() { go handler.HandleResponses() }).ShouldNot(Panic()) - handler.PushQueue.ResponseChannel() <- &structs.ResponseWithMetadata{ - Notification: &structs.ApnsNotification{}, - } - time.Sleep(50 * time.Millisecond) - apnsResMutex.Lock() - Expect(handler.responsesReceived).To(Equal(int64(1))) - apnsResMutex.Unlock() - }) - }) - - }) -}) diff --git a/extensions/common.go b/extensions/common.go index b933098..bd0ea08 100644 --- a/extensions/common.go +++ b/extensions/common.go @@ -52,7 +52,7 @@ func GetGameAndPlatformFromTopic(topic string) ParsedTopic { return getGameAndPlatformFromTopic(topic) } -func sendToFeedbackReporters(feedbackReporters []interfaces.FeedbackReporter, res interface{}, topic ParsedTopic) error { +func SendToFeedbackReporters(feedbackReporters []interfaces.FeedbackReporter, res interface{}, topic ParsedTopic) error { jres, err := json.Marshal(res) if err != nil { return err @@ -65,19 +65,19 @@ func sendToFeedbackReporters(feedbackReporters []interfaces.FeedbackReporter, re return nil } -func statsReporterHandleNotificationSent(statsReporters []interfaces.StatsReporter, game string, platform string) { +func StatsReporterHandleNotificationSent(statsReporters []interfaces.StatsReporter, game string, platform string) { for _, statsReporter := range statsReporters { statsReporter.HandleNotificationSent(game, platform) } } -func statsReporterHandleNotificationSuccess(statsReporters []interfaces.StatsReporter, game string, platform string) { +func StatsReporterHandleNotificationSuccess(statsReporters []interfaces.StatsReporter, game string, platform string) { for _, statsReporter := range statsReporters { statsReporter.HandleNotificationSuccess(game, platform) } } -func statsReporterHandleNotificationFailure( +func StatsReporterHandleNotificationFailure( statsReporters []interfaces.StatsReporter, game string, platform string, @@ -88,19 +88,19 @@ func statsReporterHandleNotificationFailure( } } -func statsReporterNotificationRateLimitReached(statsReporters []interfaces.StatsReporter, game string, platform string) { +func StatsReporterNotificationRateLimitReached(statsReporters []interfaces.StatsReporter, game string, platform string) { for _, statsReporter := range statsReporters { statsReporter.NotificationRateLimitReached(game, platform) } } -func statsReporterNotificationRateLimitFailed(statsReporters []interfaces.StatsReporter, game string, platform string) { +func StatsReporterNotificationRateLimitFailed(statsReporters []interfaces.StatsReporter, game string, platform string) { for _, statsReporter := range statsReporters { statsReporter.NotificationRateLimitFailed(game, platform) } } -func statsReporterReportSendNotificationLatency(statsReporters []interfaces.StatsReporter, latencyMs time.Duration, game string, platform string, labels ...string) { +func StatsReporterReportSendNotificationLatency(statsReporters []interfaces.StatsReporter, latencyMs time.Duration, game string, platform string, labels ...string) { for _, statsReporter := range statsReporters { statsReporter.ReportSendNotificationLatency(latencyMs, game, platform, labels...) } diff --git a/extensions/common_test.go b/extensions/common_test.go index c6b0180..4c3ab38 100644 --- a/extensions/common_test.go +++ b/extensions/common_test.go @@ -63,7 +63,7 @@ var _ = Describe("Common", func() { Describe("Send feedback to reporters", func() { It("should return an error if res cannot be marshaled", func() { badContent := make(chan int) - err := sendToFeedbackReporters(feedbackClients, badContent, ParsedTopic{}) + err := SendToFeedbackReporters(feedbackClients, badContent, ParsedTopic{}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal("json: unsupported type: chan int")) }) diff --git a/extensions/client/errors.go b/extensions/firebase/client/errors.go similarity index 100% rename from extensions/client/errors.go rename to extensions/firebase/client/errors.go diff --git a/extensions/client/firebase.go b/extensions/firebase/client/firebase.go similarity index 100% rename from extensions/client/firebase.go rename to extensions/firebase/client/firebase.go diff --git a/extensions/handler/config.go b/extensions/firebase/config.go similarity index 94% rename from extensions/handler/config.go rename to extensions/firebase/config.go index 0969c2e..dab4750 100644 --- a/extensions/handler/config.go +++ b/extensions/firebase/config.go @@ -1,4 +1,4 @@ -package handler +package firebase import "time" diff --git a/extensions/handler/feedback_response.go b/extensions/firebase/feedback_response.go similarity index 94% rename from extensions/handler/feedback_response.go rename to extensions/firebase/feedback_response.go index f40fd1b..2029c38 100644 --- a/extensions/handler/feedback_response.go +++ b/extensions/firebase/feedback_response.go @@ -1,4 +1,4 @@ -package handler +package firebase // FeedbackResponse struct is sent to feedback reporters // in order to keep the format expected by it diff --git a/extensions/handler/message_handler.go b/extensions/firebase/message_handler.go similarity index 81% rename from extensions/handler/message_handler.go rename to extensions/firebase/message_handler.go index 71afdd2..4513c80 100644 --- a/extensions/handler/message_handler.go +++ b/extensions/firebase/message_handler.go @@ -1,4 +1,4 @@ -package handler +package firebase import ( "context" @@ -17,10 +17,9 @@ type messageHandler struct { logger *logrus.Logger client interfaces.PushClient config messageHandlerConfig - stats messagesStats - statsMutex sync.Mutex feedbackReporters []interfaces.FeedbackReporter statsReporters []interfaces.StatsReporter + pendingMessagesWaitGroup *sync.WaitGroup rateLimiter interfaces.RateLimiter statsDClient extensions.StatsD sendPushConcurrencyControl chan interface{} @@ -30,6 +29,12 @@ type messageHandler struct { } } +type kafkaFCMMessage struct { + interfaces.Message + Metadata map[string]interface{} `json:"metadata,omitempty"` + PushExpiry int64 `json:"push_expiry,omitempty"` +} + var _ interfaces.MessageHandler = &messageHandler{} func NewMessageHandler( @@ -38,6 +43,7 @@ func NewMessageHandler( feedbackReporters []interfaces.FeedbackReporter, statsReporters []interfaces.StatsReporter, rateLimiter interfaces.RateLimiter, + pendingMessagesWaitGroup *sync.WaitGroup, logger *logrus.Logger, concurrentWorkers int, ) interfaces.MessageHandler { @@ -54,6 +60,7 @@ func NewMessageHandler( feedbackReporters: feedbackReporters, statsReporters: statsReporters, rateLimiter: rateLimiter, + pendingMessagesWaitGroup: pendingMessagesWaitGroup, logger: l.Logger, config: cfg, sendPushConcurrencyControl: make(chan interface{}, concurrentWorkers), @@ -74,26 +81,24 @@ func (h *messageHandler) HandleMessages(ctx context.Context, msg interfaces.Kafk l := h.logger.WithFields(logrus.Fields{ "method": "HandleMessages", }) - km := extensions.KafkaGCMMessage{} + km := kafkaFCMMessage{} err := json.Unmarshal(msg.Value, &km) if err != nil { l.WithError(err).Error("Error unmarshalling message.") + h.waitGroupDone() return } if km.PushExpiry > 0 && km.PushExpiry < extensions.MakeTimestamp() { l.Warnf("ignoring push message because it has expired: %s", km.Data) - - h.statsMutex.Lock() - h.stats.ignored++ - h.statsMutex.Unlock() - + h.waitGroupDone() return } allowed := h.rateLimiter.Allow(ctx, km.To, msg.Game, "gcm") if !allowed { h.reportRateLimitReached(msg.Game) + h.waitGroupDone() l.WithField("message", msg).Warn("rate limit reached") return } @@ -146,42 +151,16 @@ func (h *messageHandler) HandleResponses() { for { response := <-h.responsesChannel if response.error != nil { - h.handleNotificationFailure(response.error) + h.handleNotificationFailure(response.msg, response.error) } else { h.handleNotificationAck() } + h.waitGroupDone() } }() } } -func (h *messageHandler) LogStats() { - l := h.logger.WithFields(logrus.Fields{ - "method": "logStats", - "interval(ns)": h.config.statusLogInterval.Nanoseconds(), - }) - - ticker := time.NewTicker(h.config.statusLogInterval) - for range ticker.C { - h.statsMutex.Lock() - if h.stats.sent > 0 || h.stats.ignored > 0 || h.stats.failures > 0 { - l.WithFields(logrus.Fields{ - "sentMessages": h.stats.sent, - "ignoredMessages": h.stats.ignored, - "failuresReceived": h.stats.failures, - }).Info("flushing stats") - - h.stats.sent = 0 - h.stats.ignored = 0 - h.stats.failures = 0 - } - h.statsMutex.Unlock() - } -} - -func (h *messageHandler) CleanMetadataCache() { -} - func (h *messageHandler) sendToFeedbackReporters(res interface{}) error { jsonRes, err := json.Marshal(res) if err != nil { @@ -205,19 +184,9 @@ func (h *messageHandler) handleNotificationAck() { for _, statsReporter := range h.statsReporters { statsReporter.HandleNotificationSuccess(h.app, "gcm") } - - for _, feedbackReporter := range h.feedbackReporters { - r := &FeedbackResponse{} - b, _ := json.Marshal(r) - feedbackReporter.SendFeedback(h.app, "gcm", b) - } - - h.statsMutex.Lock() - h.stats.sent++ - h.statsMutex.Unlock() } -func (h *messageHandler) handleNotificationFailure(err error) { +func (h *messageHandler) handleNotificationFailure(message interfaces.Message, err error) { pushError := translateToPushError(err) for _, statsReporter := range h.statsReporters { statsReporter.HandleNotificationFailure(h.app, "gcm", pushError) @@ -226,13 +195,11 @@ func (h *messageHandler) handleNotificationFailure(err error) { feedback := &FeedbackResponse{ Error: pushError.Key, ErrorDescription: pushError.Description, + From: message.To, } b, _ := json.Marshal(feedback) feedbackReporter.SendFeedback(h.app, "gcm", b) } - h.statsMutex.Lock() - h.stats.failures++ - h.statsMutex.Unlock() } func (h *messageHandler) reportLatency(latency time.Duration) { @@ -259,3 +226,9 @@ func translateToPushError(err error) *pushErrors.PushError { } return pushErrors.NewPushError("unknown", err.Error()) } + +func (h *messageHandler) waitGroupDone() { + if h.pendingMessagesWaitGroup != nil { + h.pendingMessagesWaitGroup.Done() + } +} diff --git a/extensions/firebase/message_handler_test.go b/extensions/firebase/message_handler_test.go new file mode 100644 index 0000000..23419c9 --- /dev/null +++ b/extensions/firebase/message_handler_test.go @@ -0,0 +1,410 @@ +package firebase + +import ( + "context" + "encoding/json" + "fmt" + "github.com/topfreegames/pusher/errors" + mock_interfaces "github.com/topfreegames/pusher/mocks/interfaces" + "os" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spf13/viper" + "github.com/stretchr/testify/suite" + "github.com/topfreegames/pusher/config" + "github.com/topfreegames/pusher/extensions" + "github.com/topfreegames/pusher/interfaces" + "go.uber.org/mock/gomock" +) + +const concurrentWorkers = 5 + +type MessageHandlerTestSuite struct { + suite.Suite + vConfig *viper.Viper + config *config.Config + game string + + mockClient *mock_interfaces.MockPushClient + mockStatsReporter *mock_interfaces.MockStatsReporter + mockFeedbackReporter *mock_interfaces.MockFeedbackReporter + mockRateLimiter *mock_interfaces.MockRateLimiter + waitGroup *sync.WaitGroup + + handler interfaces.MessageHandler +} + +func TestMessageHandlerSuite(t *testing.T) { + suite.Run(t, new(MessageHandlerTestSuite)) +} + +func (s *MessageHandlerTestSuite) SetupSuite() { + file := os.Getenv("CONFIG_FILE") + if file == "" { + file = "../../config/test.yaml" + } + + config, vConfig, err := config.NewConfigAndViper(file) + s.Require().NoError(err) + s.config = config + s.vConfig = vConfig + s.game = "game" +} + +func (s *MessageHandlerTestSuite) SetupSubTest() { + ctrl := gomock.NewController(s.T()) + s.mockClient = mock_interfaces.NewMockPushClient(ctrl) + + l, _ := test.NewNullLogger() + + s.mockStatsReporter = mock_interfaces.NewMockStatsReporter(ctrl) + s.mockFeedbackReporter = mock_interfaces.NewMockFeedbackReporter(ctrl) + statsClients := []interfaces.StatsReporter{s.mockStatsReporter} + feedbackClients := []interfaces.FeedbackReporter{s.mockFeedbackReporter} + s.mockRateLimiter = mock_interfaces.NewMockRateLimiter(ctrl) + s.waitGroup = &sync.WaitGroup{} + + cfg := newDefaultMessageHandlerConfig() + cfg.concurrentResponseHandlers = concurrentWorkers + handler := NewMessageHandler( + s.game, + s.mockClient, + feedbackClients, + statsClients, + s.mockRateLimiter, + s.waitGroup, + l, + concurrentWorkers, + ) + + s.handler = handler +} + +func (s *MessageHandlerTestSuite) TestHandleMessage() { + s.Run("should fail if invalid kafka message format", func() { + msg := interfaces.KafkaMessage{ + Topic: "push-game_gcm", + Value: []byte(`not json`), + } + + s.waitGroup.Add(1) + s.handler.HandleMessages(context.Background(), msg) + + waitWG(s.T(), s.waitGroup) + }) + + s.Run("should fail if notification expired", func() { + message := interfaces.Message{} + km := &kafkaFCMMessage{ + Message: message, + PushExpiry: extensions.MakeTimestamp() - time.Hour.Milliseconds(), + } + bytes, err := json.Marshal(km) + s.Require().NoError(err) + + s.waitGroup.Add(1) + s.handler.HandleMessages(context.Background(), interfaces.KafkaMessage{Value: bytes}) + + waitWG(s.T(), s.waitGroup) + }) + + s.Run("should fail if rate limit reached", func() { + expiration := time.Now().Add(1 * time.Hour).UnixNano() + token := uuid.NewString() + msg := interfaces.KafkaMessage{ + Topic: "push-game_gcm", + Game: s.game, + Value: []byte(fmt.Sprintf(`{"To": "%s", "Payload": { "aps" : { "alert" : "Hello HTTP/2" } }, "Metadata": { "expiration": 0 }, "push_expiry": %d }`, + token, + expiration, + )), + } + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), token, s.game, "gcm"). + Return(false) + + s.mockStatsReporter.EXPECT(). + NotificationRateLimitReached(s.game, "gcm"). + Return() + + s.waitGroup.Add(1) + s.handler.HandleMessages(context.Background(), msg) + waitWG(s.T(), s.waitGroup) + }) + + s.Run("should succeed", func() { + token := uuid.NewString() + msgValue := kafkaFCMMessage{ + Message: interfaces.Message{ + To: token, + Data: map[string]interface{}{ + "title": "notification", + "body": "body", + }, + }, + Metadata: map[string]interface{}{ + "some": "metadata", + }, + } + bytes, err := json.Marshal(msgValue) + msg := interfaces.KafkaMessage{Value: bytes, Topic: "push-game_gcm", Game: s.game} + s.Require().NoError(err) + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), token, s.game, "gcm"). + Return(true) + + done := make(chan struct{}) + + s.mockClient.EXPECT(). + SendPush(gomock.Any(), gomock.Any()). + Do(func(_ context.Context, msg interfaces.Message) { + s.Equal(token, msg.To) + done <- struct{}{} + }) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.game, "gcm", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + ReportFirebaseLatency(gomock.Any(), s.game, gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.game, "gcm"). + Return() + + s.handler.HandleMessages(context.Background(), msg) + timeout := time.NewTimer(10 * time.Millisecond) + select { + case <-done: + case <-timeout.C: + s.Fail("timed out waiting for message to be sent") + } + }) + + s.Run("should not lock sendPushConcurrencyControl when sending multiple messages", func() { + newMessage := func() kafkaFCMMessage { + ttl := uint(0) + token := uuid.NewString() + title := fmt.Sprintf("title - %s", uuid.NewString()) + metadata := map[string]interface{}{ + "some": "metadata", + "game": "game", + "platform": "gcm", + } + km := kafkaFCMMessage{ + Message: interfaces.Message{ + TimeToLive: &ttl, + DeliveryReceiptRequested: false, + DryRun: true, + To: token, + Data: map[string]interface{}{ + "title": title, + }, + }, + Metadata: metadata, + PushExpiry: extensions.MakeTimestamp() + int64(1000000), + } + return km + } + go s.handler.HandleResponses() + qtyMsgs := 100 + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), gomock.Any(), s.game, "gcm"). + Return(true). + Times(qtyMsgs) + + s.mockClient.EXPECT(). + SendPush(gomock.Any(), gomock.Any()). + Return(nil). + Times(qtyMsgs) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.game, "gcm", gomock.Any()). + Times(qtyMsgs). + Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.game, "gcm"). + Times(qtyMsgs). + Return() + + done := make(chan struct{}) + s.mockStatsReporter.EXPECT(). + HandleNotificationSuccess(s.game, "gcm"). + Times(qtyMsgs). + Do(func(game, platform string) { + done <- struct{}{} + }) + + s.mockStatsReporter.EXPECT(). + ReportFirebaseLatency(gomock.Any(), s.game, gomock.Any()).Return(). + Times(qtyMsgs) + + ctx := context.Background() + for i := 0; i < qtyMsgs; i++ { + km := newMessage() + bytes, err := json.Marshal(km) + s.Require().NoError(err) + s.waitGroup.Add(1) + go s.handler.HandleMessages(ctx, interfaces.KafkaMessage{Value: bytes, Game: s.game}) + } + + timeout := time.NewTimer(50 * time.Millisecond) + for i := 0; i < qtyMsgs; i++ { + select { + case <-done: + case <-timeout.C: + s.FailNow("timed out waiting for message to be sent") + } + } + }) +} + +func (s *MessageHandlerTestSuite) TestHandleResponse() { + s.Run("should send metric and feedback on failure", func() { + token := uuid.NewString() + msgValue := kafkaFCMMessage{ + Message: interfaces.Message{ + To: token, + Data: map[string]interface{}{ + "title": "notification", + "body": "body", + }, + }, + Metadata: map[string]interface{}{ + "some": "metadata", + }, + } + bytes, err := json.Marshal(msgValue) + msg := interfaces.KafkaMessage{Value: bytes, Topic: "push-game_gcm", Game: s.game} + s.Require().NoError(err) + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), token, s.game, "gcm"). + Return(true) + + done := make(chan struct{}) + + s.mockClient.EXPECT(). + SendPush(gomock.Any(), gomock.Any()). + Return(errors.NewPushError("DEVICE_UNREGISTERED", "device unregistered")) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.game, "gcm", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + ReportFirebaseLatency(gomock.Any(), s.game, gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.game, "gcm"). + Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationFailure(s.game, "gcm", gomock.Any()) + + s.mockFeedbackReporter.EXPECT(). + SendFeedback(s.game, "gcm", gomock.Any()). + DoAndReturn(func(game, platform string, feedback []byte) { + obj := &FeedbackResponse{} + err := json.Unmarshal(feedback, obj) + s.NoError(err) + s.Equal(token, obj.From) + done <- struct{}{} + }) + + go s.handler.HandleResponses() + + s.waitGroup.Add(1) + + s.handler.HandleMessages(context.Background(), msg) + + timeout := time.NewTimer(10 * time.Millisecond) + select { + case <-done: + case <-timeout.C: + s.Fail("timed out waiting for message to be sent") + } + }) + + s.Run("should send ack metric on success", func() { + token := uuid.NewString() + msgValue := kafkaFCMMessage{ + Message: interfaces.Message{ + To: token, + Data: map[string]interface{}{ + "title": "notification", + "body": "body", + }, + }, + Metadata: map[string]interface{}{ + "some": "metadata", + }, + } + bytes, err := json.Marshal(msgValue) + msg := interfaces.KafkaMessage{Value: bytes, Topic: "push-game_gcm", Game: s.game} + s.Require().NoError(err) + + s.mockRateLimiter.EXPECT(). + Allow(gomock.Any(), token, s.game, "gcm"). + Return(true) + + done := make(chan struct{}) + + s.mockClient.EXPECT(). + SendPush(gomock.Any(), gomock.Any()). + Do(func(_ context.Context, msg interfaces.Message) { + s.Equal(token, msg.To) + }) + + s.mockStatsReporter.EXPECT(). + ReportSendNotificationLatency(gomock.Any(), s.game, "gcm", gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + ReportFirebaseLatency(gomock.Any(), s.game, gomock.Any()).Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSent(s.game, "gcm"). + Return() + + s.mockStatsReporter.EXPECT(). + HandleNotificationSuccess(s.game, "gcm"). + Do(func(game, platform string) { + done <- struct{}{} + }) + + go s.handler.HandleResponses() + + s.waitGroup.Add(1) + + s.handler.HandleMessages(context.Background(), msg) + + timeout := time.NewTimer(10 * time.Millisecond) + select { + case <-done: + case <-timeout.C: + s.Fail("timed out waiting for message to be sent") + } + }) +} + +func waitWG(t *testing.T, wg *sync.WaitGroup) { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + timeout := time.After(10 * time.Millisecond) + select { + case <-done: + case <-timeout: + t.Fatal("timed out waiting for waitgroup") + } +} diff --git a/extensions/handler/stats.go b/extensions/firebase/stats.go similarity index 82% rename from extensions/handler/stats.go rename to extensions/firebase/stats.go index 7dacac8..8feb29d 100644 --- a/extensions/handler/stats.go +++ b/extensions/firebase/stats.go @@ -1,4 +1,4 @@ -package handler +package firebase type messagesStats struct { sent int64 diff --git a/extensions/gcm_message_handler.go b/extensions/gcm_message_handler.go deleted file mode 100644 index 822a73c..0000000 --- a/extensions/gcm_message_handler.go +++ /dev/null @@ -1,504 +0,0 @@ -/* - * Copyright (c) 2016 TFG Co - * Author: TFG Co - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package extensions - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "os" - "strings" - "sync" - "time" - - "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "github.com/topfreegames/go-gcm" - pushererrors "github.com/topfreegames/pusher/errors" - "github.com/topfreegames/pusher/interfaces" -) - -var gcmResMutex sync.Mutex - -// KafkaGCMMessage is a enriched XMPPMessage with a Metadata field -type KafkaGCMMessage struct { - interfaces.Message - Metadata map[string]interface{} `json:"metadata,omitempty"` - PushExpiry int64 `json:"push_expiry,omitempty"` -} - -// CCSMessageWithMetadata is an enriched CCSMessage with a metadata field -type CCSMessageWithMetadata struct { - gcm.CCSMessage - Timestamp int64 `json:"timestamp"` - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - -// GCMMessageHandler implements the MessageHandler interface -type GCMMessageHandler struct { - feedbackReporters []interfaces.FeedbackReporter - StatsReporters []interfaces.StatsReporter - game string - GCMClient interfaces.GCMClient - ViperConfig *viper.Viper - failuresReceived int64 - InflightMessagesMetadata map[string]interface{} - Logger *logrus.Entry - LogStatsInterval time.Duration - pendingMessages chan bool - pendingMessagesWG *sync.WaitGroup - ignoredMessages int64 - inflightMessagesMetadataLock *sync.Mutex - responsesReceived int64 - sentMessages int64 - successesReceived int64 - requestsHeap *TimeoutHeap - CacheCleaningInterval int - IsProduction bool - rateLimiter interfaces.RateLimiter -} - -// NewGCMMessageHandler returns a new instance of a GCMMessageHandler -func NewGCMMessageHandler( - game string, - isProduction bool, - config *viper.Viper, - logger *logrus.Logger, - pendingMessagesWG *sync.WaitGroup, - statsReporters []interfaces.StatsReporter, - feedbackReporters []interfaces.FeedbackReporter, - rateLimiter interfaces.RateLimiter, -) (*GCMMessageHandler, error) { - l := logger.WithFields(logrus.Fields{ - "method": "NewGCMMessageHandler", - "game": game, - "isProduction": isProduction, - }) - - h, err := NewGCMMessageHandlerWithClient(game, isProduction, config, l.Logger, pendingMessagesWG, statsReporters, feedbackReporters, nil, rateLimiter) - if err != nil { - l.WithError(err).Error("Failed to create a new GCM Message handler.") - return nil, err - } - return h, nil -} - -func NewGCMMessageHandlerWithClient( - game string, - isProduction bool, - config *viper.Viper, - logger *logrus.Logger, - pendingMessagesWG *sync.WaitGroup, - statsReporters []interfaces.StatsReporter, - feedbackReporters []interfaces.FeedbackReporter, - client interfaces.GCMClient, - rateLimiter interfaces.RateLimiter, -) (*GCMMessageHandler, error) { - l := logger.WithFields(logrus.Fields{ - "method": "NewGCMMessageHandlerWithClient", - "game": game, - "isProduction": isProduction, - }) - - g := &GCMMessageHandler{ - ViperConfig: config, - failuresReceived: 0, - feedbackReporters: feedbackReporters, - game: game, - InflightMessagesMetadata: map[string]interface{}{}, - inflightMessagesMetadataLock: &sync.Mutex{}, - IsProduction: isProduction, - Logger: l, - pendingMessagesWG: pendingMessagesWG, - requestsHeap: NewTimeoutHeap(config), - StatsReporters: statsReporters, - GCMClient: client, - rateLimiter: rateLimiter, - } - - err := g.configure() - if err != nil { - l.WithError(err).Error("Failed to create a new GCM Message handler.") - return nil, err - } - return g, nil -} - -func (g *GCMMessageHandler) configure() error { - g.loadConfigurationDefaults() - - g.pendingMessages = make(chan bool, g.ViperConfig.GetInt("gcm.maxPendingMessages")) - interval := g.ViperConfig.GetInt("gcm.logStatsInterval") - g.LogStatsInterval = time.Duration(interval) * time.Millisecond - g.CacheCleaningInterval = g.ViperConfig.GetInt("feedback.cache.cleaningInterval") - - if g.GCMClient == nil { // Configures the legacy GCM client here because it needs the handleGCMResponse function - err := g.configureGCMClient() - if err != nil { - return err - } - } - - return nil -} - -func (g *GCMMessageHandler) loadConfigurationDefaults() { - g.ViperConfig.SetDefault("gcm.pingInterval", 20) - g.ViperConfig.SetDefault("gcm.pingTimeout", 30) - g.ViperConfig.SetDefault("gcm.maxPendingMessages", 100) - g.ViperConfig.SetDefault("gcm.logStatsInterval", 5000) - g.ViperConfig.SetDefault("gcm.client.initialization.retries", 3) - g.ViperConfig.SetDefault("feedback.cache.cleaningInterval", 300000) -} - -func (g *GCMMessageHandler) configureGCMClient() error { - l := g.Logger.WithField("method", "configureGCMClient") - - senderID := g.ViperConfig.GetString(fmt.Sprintf("gcm.certs.%s.senderID", g.game)) - apiKey := g.ViperConfig.GetString(fmt.Sprintf("gcm.certs.%s.apiKey", g.game)) - if senderID == "" || apiKey == "" { - l.Error("senderID or apiKey not found") - return errors.New("senderID or apiKey not found") - } - - gcmConfig := &gcm.Config{ - SenderID: senderID, - APIKey: apiKey, - Sandbox: !g.IsProduction, - MonitorConnection: true, - Debug: false, - } - - var err error - var cl interfaces.GCMClient - for retries := g.ViperConfig.GetInt("gcm.client.initialization.retries"); retries > 0; retries-- { - cl, err = gcm.NewClient(gcmConfig, g.handleGCMResponse) - if err != nil && retries-1 != 0 { - l.WithError(err).Warnf("failed to create gcm client. %d attempts left.", retries-1) - } else { - break - } - } - if err != nil { - l.WithError(err).Error("failed to create gcm client.") - return err - } - g.GCMClient = cl - return nil -} - -// WARNING: Be careful, code here needs to be thread safe! -func (g *GCMMessageHandler) handleGCMResponse(cm gcm.CCSMessage) error { - defer func() { - if g.pendingMessagesWG != nil { - g.pendingMessagesWG.Done() - } - }() - - l := g.Logger.WithFields(logrus.Fields{ - "method": "handleGCMResponse", - "ccsMessage": cm, - }) - l.Debug("Got response from gcm.") - gcmResMutex.Lock() - - select { - case <-g.pendingMessages: - l.Debug("Freeing pendingMessages channel") - default: - l.Warn("No pending messages in channel but received response.") - } - - g.responsesReceived++ - gcmResMutex.Unlock() - - var err error - ccsMessageWithMetadata := &CCSMessageWithMetadata{ - CCSMessage: cm, - } - parsedTopic := ParsedTopic{} - g.inflightMessagesMetadataLock.Lock() - if val, ok := g.InflightMessagesMetadata[cm.MessageID]; ok { - ccsMessageWithMetadata.Metadata = val.(map[string]interface{}) - ccsMessageWithMetadata.Timestamp = ccsMessageWithMetadata.Metadata["timestamp"].(int64) - parsedTopic.Game = ccsMessageWithMetadata.Metadata["game"].(string) - parsedTopic.Platform = ccsMessageWithMetadata.Metadata["platform"].(string) - delete(ccsMessageWithMetadata.Metadata, "timestamp") - delete(g.InflightMessagesMetadata, cm.MessageID) - } - g.inflightMessagesMetadataLock.Unlock() - - if cm.Error != "" { - gcmResMutex.Lock() - g.failuresReceived++ - gcmResMutex.Unlock() - pErr := pushererrors.NewPushError(strings.ToLower(cm.Error), cm.ErrorDescription) - statsReporterHandleNotificationFailure(g.StatsReporters, parsedTopic.Game, "gcm", pErr) - - err = pErr - switch cm.Error { - // errors from https://developers.google.com/cloud-messaging/xmpp-server-ref table 4 - case "DEVICE_UNREGISTERED", "BAD_REGISTRATION": - l.WithFields(logrus.Fields{ - "category": "TokenError", - logrus.ErrorKey: fmt.Errorf("%s (Description: %s)", cm.Error, cm.ErrorDescription), - }).Debug("received an error") - if ccsMessageWithMetadata.Metadata != nil { - ccsMessageWithMetadata.Metadata["deleteToken"] = true - } - case "INVALID_JSON": - l.WithFields(logrus.Fields{ - "category": "JsonError", - logrus.ErrorKey: fmt.Errorf("%s (Description: %s)", cm.Error, cm.ErrorDescription), - }).Debug("received an error") - case "SERVICE_UNAVAILABLE", "INTERNAL_SERVER_ERROR": - l.WithFields(logrus.Fields{ - "category": "GoogleError", - logrus.ErrorKey: cm.Error, - }).Debug("received an error") - case "DEVICE_MESSAGE_RATE_EXCEEDED", "TOPICS_MESSAGE_RATE_EXCEEDED": - l.WithFields(logrus.Fields{ - "category": "RateExceededError", - logrus.ErrorKey: cm.Error, - }).Debug("received an error") - case "CONNECTION_DRAINING": - l.WithFields(logrus.Fields{ - "category": "ConnectionDrainingError", - logrus.ErrorKey: cm.Error, - }).Debug("received an error") - default: - l.WithFields(logrus.Fields{ - "category": "DefaultError", - logrus.ErrorKey: cm.Error, - }).Debug("received an error") - } - sendFeedbackErr := sendToFeedbackReporters(g.feedbackReporters, ccsMessageWithMetadata, parsedTopic) - if sendFeedbackErr != nil { - l.WithError(sendFeedbackErr).Error("error sending feedback to reporter") - } - return err - } - - sendFeedbackErr := sendToFeedbackReporters(g.feedbackReporters, ccsMessageWithMetadata, parsedTopic) - if sendFeedbackErr != nil { - l.WithError(sendFeedbackErr).Error("error sending feedback to reporter") - } - - gcmResMutex.Lock() - g.successesReceived++ - gcmResMutex.Unlock() - statsReporterHandleNotificationSuccess(g.StatsReporters, parsedTopic.Game, "gcm") - - return nil -} - -func (g *GCMMessageHandler) sendMessage(message interfaces.KafkaMessage) error { - l := g.Logger.WithField("method", "sendMessage") - //ttl := uint(0) - km := KafkaGCMMessage{} - err := json.Unmarshal(message.Value, &km) - if err != nil { - l.WithError(err).Error("Error unmarshalling message.") - return err - } - if km.PushExpiry > 0 && km.PushExpiry < MakeTimestamp() { - l.Warnf("ignoring push message because it has expired: %s", km.Data) - g.ignoredMessages++ - if g.pendingMessagesWG != nil { - g.pendingMessagesWG.Done() - } - return nil - } - - if km.Metadata != nil { - if km.Message.Data == nil { - km.Message.Data = map[string]interface{}{} - } - - for k, v := range km.Metadata { - if km.Message.Data[k] == nil { - km.Message.Data[k] = v - } - } - } - - l = l.WithField("message", km) - - allowed := g.rateLimiter.Allow(context.Background(), km.To, message.Game, "gcm") - if !allowed { - statsReporterNotificationRateLimitReached(g.StatsReporters, message.Game, "gcm") - l.WithField("message", message).Warn("rate limit reached") - return errors.New("rate limit reached") - } - l.Debug("sending message to gcm") - - var messageID string - var bytes int - - g.pendingMessages <- true - xmppMessage := toGCMMessage(km.Message) - - before := time.Now() - messageID, bytes, err = g.GCMClient.SendXMPP(xmppMessage) - elapsed := time.Since(before) - statsReporterReportSendNotificationLatency(g.StatsReporters, elapsed, g.game, "gcm", "client", "gcm") - - if err != nil { - <-g.pendingMessages - l.WithError(err).Error("Error sending message.") - return err - } - - if messageID != "" { - if km.Metadata == nil { - km.Metadata = map[string]interface{}{} - } - - km.Metadata["timestamp"] = time.Now().Unix() - hostname, err := os.Hostname() - if err != nil { - l.WithError(err).Error("error retrieving hostname") - } else { - km.Metadata["hostname"] = hostname - } - - km.Metadata["game"] = message.Game - km.Metadata["platform"] = "gcm" - - g.inflightMessagesMetadataLock.Lock() - g.InflightMessagesMetadata[messageID] = km.Metadata - g.requestsHeap.AddRequest(messageID) - g.inflightMessagesMetadataLock.Unlock() - } - - statsReporterHandleNotificationSent(g.StatsReporters, message.Game, "gcm") - - gcmResMutex.Lock() - g.sentMessages++ - gcmResMutex.Unlock() - - l.WithFields(logrus.Fields{ - "messageID": messageID, - "bytes": bytes, - }).Debug("sent message") - - return nil -} - -func toGCMMessage(message interfaces.Message) gcm.XMPPMessage { - gcmMessage := gcm.XMPPMessage{ - To: message.To, - MessageID: message.MessageID, - MessageType: message.MessageType, - CollapseKey: message.CollapseKey, - Priority: message.Priority, - ContentAvailable: message.ContentAvailable, - TimeToLive: message.TimeToLive, - DeliveryReceiptRequested: message.DeliveryReceiptRequested, - DryRun: message.DryRun, - Data: gcm.Data(message.Data), - } - - if message.Notification != nil { - gcmMessage.Notification = &gcm.Notification{ - Title: message.Notification.Title, - Body: message.Notification.Body, - Sound: message.Notification.Sound, - ClickAction: message.Notification.ClickAction, - BodyLocKey: message.Notification.BodyLocKey, - BodyLocArgs: message.Notification.BodyLocArgs, - TitleLocKey: message.Notification.TitleLocKey, - TitleLocArgs: message.Notification.TitleLocArgs, - Icon: message.Notification.Icon, - Tag: message.Notification.Tag, - Color: message.Notification.Color, - Badge: message.Notification.Badge, - } - } - - return gcmMessage -} - -// HandleResponses from gcm -func (g *GCMMessageHandler) HandleResponses() { -} - -// CleanMetadataCache clears cache after timeout -func (g *GCMMessageHandler) CleanMetadataCache() { - var deviceToken string - var hasIndeed bool - for { - g.inflightMessagesMetadataLock.Lock() - for deviceToken, hasIndeed = g.requestsHeap.HasExpiredRequest(); hasIndeed; { - delete(g.InflightMessagesMetadata, deviceToken) - deviceToken, hasIndeed = g.requestsHeap.HasExpiredRequest() - } - g.inflightMessagesMetadataLock.Unlock() - - duration := time.Duration(g.CacheCleaningInterval) - time.Sleep(duration * time.Millisecond) - } -} - -// HandleMessages get messages from msgChan and send to GCM -func (g *GCMMessageHandler) HandleMessages(_ context.Context, msg interfaces.KafkaMessage) { - _ = g.sendMessage(msg) -} - -// LogStats from time to time -func (g *GCMMessageHandler) LogStats() { - l := g.Logger.WithFields(logrus.Fields{ - "method": "gcmMessageHandler.logStats", - "interval(ns)": g.LogStatsInterval, - }) - - ticker := time.NewTicker(g.LogStatsInterval) - for range ticker.C { - gcmResMutex.Lock() - if g.sentMessages > 0 || g.responsesReceived > 0 || g.ignoredMessages > 0 || g.successesReceived > 0 || g.failuresReceived > 0 { - l.WithFields(logrus.Fields{ - "sentMessages": g.sentMessages, - "responsesReceived": g.responsesReceived, - "ignoredMessages": g.ignoredMessages, - "successesReceived": g.successesReceived, - "failuresReceived": g.failuresReceived, - }).Info("flushing stats") - g.sentMessages = 0 - g.responsesReceived = 0 - g.successesReceived = 0 - g.ignoredMessages = 0 - g.failuresReceived = 0 - } - gcmResMutex.Unlock() - } -} - -// Cleanup closes connections to GCM -func (g *GCMMessageHandler) Cleanup() error { - err := g.GCMClient.Close() - if err != nil { - return err - } - return nil -} diff --git a/extensions/gcm_message_handler_test.go b/extensions/gcm_message_handler_test.go deleted file mode 100644 index 3e14321..0000000 --- a/extensions/gcm_message_handler_test.go +++ /dev/null @@ -1,874 +0,0 @@ -/* - * Copyright (c) 2016 TFG Co - * Author: TFG Co - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -package extensions - -import ( - "encoding/json" - "os" - "testing" - "time" - - "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "github.com/stretchr/testify/suite" - "github.com/topfreegames/pusher/config" - - uuid "github.com/satori/go.uuid" - "github.com/sirupsen/logrus/hooks/test" - "github.com/topfreegames/go-gcm" - "github.com/topfreegames/pusher/interfaces" - "github.com/topfreegames/pusher/mocks" -) - -type GCMMessageHandlerTestSuite struct { - suite.Suite - - config *config.Config - vConfig *viper.Viper - game string -} - -func TestGCMMessageHandlerSuite(t *testing.T) { - suite.Run(t, new(GCMMessageHandlerTestSuite)) -} - -func (s *GCMMessageHandlerTestSuite) SetupSuite() { - configFile := os.Getenv("CONFIG_FILE") - if configFile == "" { - configFile = "../config/test.yaml" - } - c, vConfig, err := config.NewConfigAndViper(configFile) - s.Require().NoError(err) - - s.config = c - s.vConfig = vConfig - s.game = "game" -} - -func (s *GCMMessageHandlerTestSuite) setupHandler() ( - *GCMMessageHandler, - *mocks.GCMClientMock, - *mocks.StatsDClientMock, - *mocks.KafkaProducerClientMock, -) { - logger, _ := test.NewNullLogger() - mockClient := mocks.NewGCMClientMock() - mockStatsdClient := mocks.NewStatsDClientMock() - mockRateLimiter := mocks.NewRateLimiterMock() - - statsD, err := NewStatsD(s.vConfig, logger, mockStatsdClient) - s.Require().NoError(err) - - mockKafkaProducer := mocks.NewKafkaProducerClientMock() - kc, err := NewKafkaProducer(s.vConfig, logger, mockKafkaProducer) - s.Require().NoError(err) - - statsClients := []interfaces.StatsReporter{statsD} - feedbackClients := []interfaces.FeedbackReporter{kc} - handler, err := NewGCMMessageHandlerWithClient( - s.game, - false, - s.vConfig, - logger, - nil, - statsClients, - feedbackClients, - mockClient, - mockRateLimiter, - ) - s.NoError(err) - s.Require().NotNil(handler) - s.Equal(s.game, handler.game) - s.NotNil(handler.ViperConfig) - s.False(handler.IsProduction) - s.Equal(int64(0), handler.responsesReceived) - s.Equal(int64(0), handler.sentMessages) - s.Len(mockClient.MessagesSent, 0) - - return handler, mockClient, mockStatsdClient, mockKafkaProducer -} - -func (s *GCMMessageHandlerTestSuite) TestConfigureHandler() { - s.Run("should fail if invalid credentials", func() { - handler, err := NewGCMMessageHandler( - s.game, - false, - s.vConfig, - logrus.New(), - nil, - []interfaces.StatsReporter{}, - []interfaces.FeedbackReporter{}, - nil, - ) - s.Error(err) - s.Nil(handler) - s.Equal("error connecting gcm xmpp client: auth failure: not-authorized", err.Error()) - }) -} - -func (s *GCMMessageHandlerTestSuite) TestHandleGCMResponse() { - s.Run("should succeed if response has no error", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{} - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.NoError(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.successesReceived) - }) - - s.Run("if response has error DEVICE_UNREGISTERED", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "DEVICE_UNREGISTERED", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has error BAD_REGISTRATION", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "BAD_REGISTRATION", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has error INVALID_JSON", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "INVALID_JSON", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has error SERVICE_UNAVAILABLE", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "SERVICE_UNAVAILABLE", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has error INTERNAL_SERVER_ERROR", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "INTERNAL_SERVER_ERROR", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has error DEVICE_MESSAGE_RATE_EXCEEDED", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "DEVICE_MESSAGE_RATE_EXCEEDED", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has error TOPICS_MESSAGE_RATE_EXCEEDED", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "TOPICS_MESSAGE_RATE_EXCEEDED", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) - - s.Run("if response has untracked error", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - Error: "BAD_ACK", - } - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.handleGCMResponse(res) - s.Error(err) - s.Equal(int64(1), handler.responsesReceived) - s.Equal(int64(1), handler.failuresReceived) - }) -} - -func (s *GCMMessageHandlerTestSuite) TestSendMessage() { - s.Run("should not send message if expire is in the past", func() { - handler, _, _, _ := s.setupHandler() - ttl := uint(0) - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "gcm", - } - msg := &KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{ - "title": "notification", - }, - }, - Metadata: metadata, - PushExpiry: MakeTimestamp() - int64(100), - } - msgBytes, err := json.Marshal(msg) - s.Require().NoError(err) - - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - s.Require().NoError(err) - s.Equal(int64(0), handler.sentMessages) - s.Equal(int64(1), handler.ignoredMessages) - }) - - s.Run("should send message if PushExpiry is in the future", func() { - handler, _, _, _ := s.setupHandler() - ttl := uint(0) - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "gcm", - } - msg := &KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{ - "title": "notification", - }, - }, - Metadata: metadata, - PushExpiry: MakeTimestamp() + int64(1000000), - } - msgBytes, err := json.Marshal(msg) - s.Require().NoError(err) - - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - s.Require().NoError(err) - gcmResMutex.Lock() - s.Equal(int64(1), handler.sentMessages) - s.Equal(int64(0), handler.ignoredMessages) - gcmResMutex.Unlock() - }) - - s.Run("should send message and not increment sentMessages if an error occurs", func() { - handler, mockClient, _, _ := s.setupHandler() - err := handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: []byte("value"), - }) - s.Require().Error(err) - gcmResMutex.Lock() - s.Equal(int64(0), handler.sentMessages) - s.Len(handler.pendingMessages, 0) - gcmResMutex.Unlock() - s.Len(mockClient.MessagesSent, 0) - }) - - s.Run("should send xmpp message", func() { - handler, mockClient, _, _ := s.setupHandler() - ttl := uint(0) - msg := &interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{ - "title": "notification", - }, - } - msgBytes, err := json.Marshal(msg) - s.Require().NoError(err) - - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - s.Require().NoError(err) - gcmResMutex.Lock() - s.Equal(int64(1), handler.sentMessages) - s.Len(mockClient.MessagesSent, 1) - s.Len(handler.pendingMessages, 1) - gcmResMutex.Unlock() - }) - - s.Run("should send xmpp message with metadata", func() { - handler, mockClient, _, _ := s.setupHandler() - ttl := uint(0) - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "gcm", - } - msg := &KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{ - "title": "notification", - }, - }, - Metadata: metadata, - PushExpiry: MakeTimestamp() + int64(1000000), - } - msgBytes, err := json.Marshal(msg) - s.Require().NoError(err) - - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - s.Require().NoError(err) - gcmResMutex.Lock() - s.Equal(int64(1), handler.sentMessages) - s.Len(mockClient.MessagesSent, 1) - s.Len(handler.pendingMessages, 1) - gcmResMutex.Unlock() - }) - - s.Run("should forward metadata content on GCM request", func() { - handler, mockClient, _, _ := s.setupHandler() - ttl := uint(0) - metadata := map[string]interface{}{ - "some": "metadata", - } - msg := &KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{ - "title": "notification", - }, - }, - Metadata: metadata, - PushExpiry: MakeTimestamp() + int64(1000000), - } - msgBytes, err := json.Marshal(msg) - s.Require().NoError(err) - - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - - s.Require().NoError(err) - gcmResMutex.Lock() - s.Equal(int64(1), handler.sentMessages) - s.Len(mockClient.MessagesSent, 1) - s.Len(handler.pendingMessages, 1) - gcmResMutex.Unlock() - - sentMessage := mockClient.MessagesSent[0] - s.NotNil(sentMessage) - s.Equal("metadata", sentMessage.Data["some"]) - }) - - s.Run("should forward nested metadata content on GCM request", func() { - handler, mockClient, _, _ := s.setupHandler() - ttl := uint(0) - metadata := map[string]interface{}{ - "some": "metadata", - } - msg := &KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{ - "nested": map[string]interface{}{ - "some": "data", - }, - }, - }, - Metadata: metadata, - PushExpiry: MakeTimestamp() + int64(1000000), - } - msgBytes, err := json.Marshal(msg) - s.Require().NoError(err) - - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - - s.Require().NoError(err) - gcmResMutex.Lock() - s.Equal(int64(1), handler.sentMessages) - s.Len(mockClient.MessagesSent, 1) - s.Len(handler.pendingMessages, 1) - gcmResMutex.Unlock() - - sentMessage := mockClient.MessagesSent[0] - s.NotNil(sentMessage) - s.Equal("metadata", sentMessage.Data["some"]) - s.Len(sentMessage.Data["nested"], 1) - s.Equal("data", sentMessage.Data["nested"].(map[string]interface{})["some"]) - }) - - s.Run("should wait to send message if maxPendingMessages is reached", func() { - handler, _, _, _ := s.setupHandler() - ttl := uint(0) - msg := &gcm.XMPPMessage{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{}, - } - msgBytes, err := json.Marshal(msg) - s.NoError(err) - - for i := 1; i <= 3; i++ { - err = handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - s.NoError(err) - s.Equal(int64(i), handler.sentMessages) - s.Equal(i, len(handler.pendingMessages)) - } - - go func() { - err := handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: msgBytes, - }) - s.Require().NoError(err) - }() - - <-handler.pendingMessages - s.Eventually( - func() bool { - gcmResMutex.Lock() - defer gcmResMutex.Unlock() - return handler.sentMessages == 4 - }, - 5*time.Second, - 100*time.Millisecond, - ) - }) -} - -func (s *GCMMessageHandlerTestSuite) TestCleanCache() { - s.Run("should remove from push queue after timeout", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`), - }) - s.Require().NoError(err) - - go handler.CleanMetadataCache() - - time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond) - - s.True(handler.requestsHeap.Empty()) - handler.inflightMessagesMetadataLock.Lock() - s.Empty(handler.InflightMessagesMetadata) - handler.inflightMessagesMetadataLock.Unlock() - }) - - s.Run("should succeed if request gets a response", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - err := handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`), - }) - s.Require().NoError(err) - - go handler.CleanMetadataCache() - - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "ack", - Category: "testCategory", - } - err = handler.handleGCMResponse(res) - s.Require().NoError(err) - - time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond) - - s.True(handler.requestsHeap.Empty()) - handler.inflightMessagesMetadataLock.Lock() - s.Empty(handler.InflightMessagesMetadata) - handler.inflightMessagesMetadataLock.Unlock() - }) - - s.Run("should handle all responses or remove them after timeout", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - n := 10 - sendRequests := func() { - for i := 0; i < n; i++ { - err := handler.sendMessage(interfaces.KafkaMessage{ - Topic: "push-game_gcm", - Value: []byte(`{ "aps" : { "alert" : "Hello HTTP/2" } }`), - }) - s.Require().NoError(err) - } - } - - handleResponses := func() { - for i := 0; i < n/2; i++ { - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "ack", - Category: "testCategory", - } - - err := handler.handleGCMResponse(res) - s.Require().NoError(err) - } - } - - go handler.CleanMetadataCache() - go sendRequests() - time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond) - - go handleResponses() - time.Sleep(2 * time.Duration(s.vConfig.GetInt("feedback.cache.requestTimeout")) * time.Millisecond) - - s.True(handler.requestsHeap.Empty()) - handler.inflightMessagesMetadataLock.Lock() - s.Empty(handler.InflightMessagesMetadata) - handler.inflightMessagesMetadataLock.Unlock() - }) -} - -func (s *GCMMessageHandlerTestSuite) TestLogStats() { - s.Run("should log stats and reset them", func() { - handler, _, _, _ := s.setupHandler() - handler.sentMessages = 100 - handler.responsesReceived = 90 - handler.successesReceived = 60 - handler.failuresReceived = 30 - handler.ignoredMessages = 10 - - go handler.LogStats() - - s.Eventually(func() bool { - gcmResMutex.Lock() - defer gcmResMutex.Unlock() - return handler.sentMessages == int64(0) && - handler.responsesReceived == int64(0) && - handler.successesReceived == int64(0) && - handler.failuresReceived == int64(0) && - handler.ignoredMessages == int64(0) - }, time.Second, time.Millisecond*100) - }) -} - -func (s *GCMMessageHandlerTestSuite) TestStatsReporter() { - s.Run("should call HandleNotificationSent upon message sent to queue", func() { - handler, _, mockStatsdClient, mockKafkaProducer := s.setupHandler() - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - ttl := uint(0) - msg := &gcm.XMPPMessage{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: uuid.NewV4().String(), - Data: map[string]interface{}{}, - } - msgBytes, err := json.Marshal(msg) - s.NoError(err) - - kafkaMessage := interfaces.KafkaMessage{ - Game: "game", - Topic: "push-game_gcm", - Value: msgBytes, - } - err = handler.sendMessage(kafkaMessage) - s.NoError(err) - - err = handler.sendMessage(kafkaMessage) - s.NoError(err) - s.Equal(int64(2), mockStatsdClient.Counts["sent"]) - }) - - s.Run("should call HandleNotificationSuccess upon response received", func() { - handler, _, mockStatsdClient, mockKafkaProducer := s.setupHandler() - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - res := gcm.CCSMessage{} - err := handler.handleGCMResponse(res) - s.Require().NoError(err) - err = handler.handleGCMResponse(res) - s.Require().NoError(err) - - s.Equal(int64(2), mockStatsdClient.Counts["ack"]) - }) - - s.Run("should call HandleNotificationFailure upon error response received", func() { - handler, _, mockStatsdClient, mockKafkaProducer := s.setupHandler() - mockKafkaProducer.StartConsumingMessagesInProduceChannel() - res := gcm.CCSMessage{ - Error: "DEVICE_UNREGISTERED", - } - err := handler.handleGCMResponse(res) - s.Error(err) - - s.Equal(int64(1), mockStatsdClient.Counts["failed"]) - }) - -} - -func (s *GCMMessageHandlerTestSuite) TestFeedbackReporter() { - s.Run("should include a timestamp in feedback root and the hostname in metadata", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - timestampNow := time.Now().Unix() - hostname, err := os.Hostname() - s.Require().NoError(err) - - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": timestampNow, - "hostname": hostname, - "game": "game", - "platform": "gcm", - } - handler.inflightMessagesMetadataLock.Lock() - handler.InflightMessagesMetadata["idTest1"] = metadata - handler.inflightMessagesMetadataLock.Unlock() - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "ack", - Category: "testCategory", - } - go func() { - err := handler.handleGCMResponse(res) - s.Require().NoError(err) - }() - - fromKafka := &CCSMessageWithMetadata{} - msg := <-mockKafkaProducer.ProduceChannel() - err = json.Unmarshal(msg.Value, fromKafka) - s.Require().NoError(err) - s.Equal(timestampNow, fromKafka.Timestamp) - s.Equal(hostname, fromKafka.Metadata["hostname"]) - }) - - s.Run("should send feedback if success and metadata is present", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "gcm", - } - - handler.inflightMessagesMetadataLock.Lock() - handler.InflightMessagesMetadata["idTest1"] = metadata - handler.inflightMessagesMetadataLock.Unlock() - - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "ack", - Category: "testCategory", - } - go func() { - err := handler.handleGCMResponse(res) - s.Require().NoError(err) - }() - - fromKafka := &CCSMessageWithMetadata{} - msg := <-mockKafkaProducer.ProduceChannel() - err := json.Unmarshal(msg.Value, fromKafka) - s.Require().NoError(err) - s.Equal(res.From, fromKafka.From) - s.Equal(res.MessageID, fromKafka.MessageID) - s.Equal(res.MessageType, fromKafka.MessageType) - s.Equal(res.Category, fromKafka.Category) - s.Equal(metadata["some"], fromKafka.Metadata["some"]) - }) - - s.Run("should send feedback if success and metadata is not present", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "ack", - Category: "testCategory", - } - go func() { - err := handler.handleGCMResponse(res) - s.Require().NoError(err) - }() - - fromKafka := &CCSMessageWithMetadata{} - msg := <-mockKafkaProducer.ProduceChannel() - err := json.Unmarshal(msg.Value, fromKafka) - s.Require().NoError(err) - s.Equal(res.From, fromKafka.From) - s.Equal(res.MessageID, fromKafka.MessageID) - s.Equal(res.MessageType, fromKafka.MessageType) - s.Equal(res.Category, fromKafka.Category) - s.Nil(fromKafka.Metadata) - }) - - s.Run("should send feedback if error and metadata is present and token should be deleted", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "gcm", - } - - handler.inflightMessagesMetadataLock.Lock() - handler.InflightMessagesMetadata["idTest1"] = metadata - handler.inflightMessagesMetadataLock.Unlock() - - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "nack", - Category: "testCategory", - Error: "BAD_REGISTRATION", - } - go func() { - err := handler.handleGCMResponse(res) - s.Error(err) - }() - - fromKafka := &CCSMessageWithMetadata{} - msg := <-mockKafkaProducer.ProduceChannel() - err := json.Unmarshal(msg.Value, fromKafka) - s.Require().NoError(err) - s.Equal(res.From, fromKafka.From) - s.Equal(res.MessageID, fromKafka.MessageID) - s.Equal(res.MessageType, fromKafka.MessageType) - s.Equal(res.Category, fromKafka.Category) - s.Equal(res.Error, fromKafka.Error) - s.Equal(metadata["some"], fromKafka.Metadata["some"]) - s.True(fromKafka.Metadata["deleteToken"].(bool)) - }) - - s.Run("should send feedback if error and metadata is present and token should not be deleted", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - metadata := map[string]interface{}{ - "some": "metadata", - "timestamp": time.Now().Unix(), - "game": "game", - "platform": "gcm", - } - - handler.inflightMessagesMetadataLock.Lock() - handler.InflightMessagesMetadata["idTest1"] = metadata - handler.inflightMessagesMetadataLock.Unlock() - - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "nack", - Category: "testCategory", - Error: "INVALID_JSON", - } - go func() { - err := handler.handleGCMResponse(res) - s.Error(err) - }() - - fromKafka := &CCSMessageWithMetadata{} - msg := <-mockKafkaProducer.ProduceChannel() - err := json.Unmarshal(msg.Value, fromKafka) - s.Require().NoError(err) - s.Equal(res.From, fromKafka.From) - s.Equal(res.MessageID, fromKafka.MessageID) - s.Equal(res.MessageType, fromKafka.MessageType) - s.Equal(res.Category, fromKafka.Category) - s.Equal(res.Error, fromKafka.Error) - s.Equal(metadata["some"], fromKafka.Metadata["some"]) - s.Nil(fromKafka.Metadata["deleteToken"]) - }) - s.Run("should send feedback if error and metadata is not present", func() { - handler, _, _, mockKafkaProducer := s.setupHandler() - res := gcm.CCSMessage{ - From: "testToken1", - MessageID: "idTest1", - MessageType: "nack", - Category: "testCategory", - Error: "BAD_REGISTRATION", - } - go func() { - err := handler.handleGCMResponse(res) - s.Error(err) - }() - - fromKafka := &CCSMessageWithMetadata{} - msg := <-mockKafkaProducer.ProduceChannel() - err := json.Unmarshal(msg.Value, fromKafka) - s.Require().NoError(err) - s.Equal(res.From, fromKafka.From) - s.Equal(res.MessageID, fromKafka.MessageID) - s.Equal(res.MessageType, fromKafka.MessageType) - s.Equal(res.Category, fromKafka.Category) - s.Equal(res.Error, fromKafka.Error) - s.Nil(fromKafka.Metadata) - }) -} diff --git a/extensions/handler/message_handler_test.go b/extensions/handler/message_handler_test.go deleted file mode 100644 index c7c328e..0000000 --- a/extensions/handler/message_handler_test.go +++ /dev/null @@ -1,323 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "fmt" - "os" - "testing" - "time" - - "github.com/google/uuid" - "github.com/sirupsen/logrus/hooks/test" - "github.com/spf13/viper" - "github.com/stretchr/testify/suite" - "github.com/topfreegames/pusher/config" - pushererrors "github.com/topfreegames/pusher/errors" - "github.com/topfreegames/pusher/extensions" - "github.com/topfreegames/pusher/interfaces" - "github.com/topfreegames/pusher/mocks" - mock_interfaces "github.com/topfreegames/pusher/mocks/firebase" - "go.uber.org/mock/gomock" -) - -const concurrentWorkers = 5 - -type MessageHandlerTestSuite struct { - suite.Suite - vConfig *viper.Viper - config *config.Config - game string - - mockClient *mock_interfaces.MockPushClient - mockStatsdClient *mocks.StatsDClientMock - mockKafkaProducer *mocks.KafkaProducerClientMock - - handler *messageHandler -} - -func TestMessageHandlerSuite(t *testing.T) { - suite.Run(t, new(MessageHandlerTestSuite)) -} - -func (s *MessageHandlerTestSuite) SetupSuite() { - file := os.Getenv("CONFIG_FILE") - if file == "" { - file = "../../config/test.yaml" - } - - config, vConfig, err := config.NewConfigAndViper(file) - s.Require().NoError(err) - s.config = config - s.vConfig = vConfig - s.game = "game" -} - -func (s *MessageHandlerTestSuite) SetupSubTest() { - ctrl := gomock.NewController(s.T()) - s.mockClient = mock_interfaces.NewMockPushClient(ctrl) - - l, _ := test.NewNullLogger() - - s.mockStatsdClient = mocks.NewStatsDClientMock() - statsD, err := extensions.NewStatsD(s.vConfig, l, s.mockStatsdClient) - s.Require().NoError(err) - - s.mockKafkaProducer = mocks.NewKafkaProducerClientMock() - kc, err := extensions.NewKafkaProducer(s.vConfig, l, s.mockKafkaProducer) - s.Require().NoError(err) - - statsClients := []interfaces.StatsReporter{statsD} - feedbackClients := []interfaces.FeedbackReporter{kc} - mockRateLimiter := mocks.NewRateLimiterMock() - - cfg := newDefaultMessageHandlerConfig() - cfg.concurrentResponseHandlers = concurrentWorkers - handler := &messageHandler{ - app: s.game, - client: s.mockClient, - feedbackReporters: feedbackClients, - statsReporters: statsClients, - rateLimiter: mockRateLimiter, - logger: l, - config: cfg, - sendPushConcurrencyControl: make(chan interface{}, concurrentWorkers), - responsesChannel: make(chan struct { - msg interfaces.Message - error error - }, concurrentWorkers), - } - - // sendPushConcurrencyControl is a channel that controls the number of concurrent workers - // that are going to be sending messages to FCM. Once the message is sent, the struct is going - // to be pushed back into the channel buffer to be reused. - for i := 0; i < concurrentWorkers; i++ { - handler.sendPushConcurrencyControl <- struct{}{} - } - - s.NoError(err) - s.Require().NotNil(handler) - - s.handler = handler -} - -func (s *MessageHandlerTestSuite) TestSendMessage() { - ctx := context.Background() - s.Run("should do nothing for bad message format", func() { - message := interfaces.KafkaMessage{ - Value: []byte("bad message"), - } - s.handler.HandleMessages(ctx, message) - - s.handler.statsMutex.Lock() - s.Equal(int64(0), s.handler.stats.sent) - s.Equal(int64(0), s.handler.stats.failures) - s.Equal(int64(0), s.handler.stats.ignored) - s.handler.statsMutex.Unlock() - }) - - s.Run("should ignore message if it has expired", func() { - message := interfaces.Message{} - km := &extensions.KafkaGCMMessage{ - Message: message, - PushExpiry: extensions.MakeTimestamp() - time.Hour.Milliseconds(), - } - bytes, err := json.Marshal(km) - s.Require().NoError(err) - - s.handler.HandleMessages(ctx, interfaces.KafkaMessage{Value: bytes}) - s.handler.statsMutex.Lock() - s.Equal(int64(1), s.handler.stats.ignored) - s.handler.statsMutex.Unlock() - }) - - s.Run("should report failure if cannot send message", func() { - ttl := uint(0) - token := "token" - metadata := map[string]interface{}{ - "some": "metadata", - "game": "game", - "platform": "gcm", - } - km := &extensions.KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: token, - Data: map[string]interface{}{ - "title": "notification", - }, - }, - Metadata: metadata, - PushExpiry: extensions.MakeTimestamp() + int64(1000000), - } - - expected := interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: token, - Data: map[string]interface{}{ - "title": "notification", - "some": "metadata", - "game": "game", - "platform": "gcm", - }, - } - - bytes, err := json.Marshal(km) - s.Require().NoError(err) - - s.mockClient.EXPECT(). - SendPush(gomock.Any(), expected). - Return(pushererrors.NewPushError("INVALID_TOKEN", "invalid token")) - - s.handler.HandleMessages(ctx, interfaces.KafkaMessage{Value: bytes}) - s.handler.HandleResponses() - - select { - case m := <-s.mockKafkaProducer.ProduceChannel(): - val := &FeedbackResponse{} - err = json.Unmarshal(m.Value, val) - s.NoError(err) - s.Equal("INVALID_TOKEN", val.Error) - case <-time.After(time.Second * 1): - s.Fail("did not send feedback to kafka") - } - - s.handler.statsMutex.Lock() - s.Equal(int64(1), s.handler.stats.failures) - s.handler.statsMutex.Unlock() - - s.Equal(int64(1), s.mockStatsdClient.Counts["failed"]) - }) - - s.Run("should report sent and success if message was sent", func() { - ttl := uint(0) - token := "token" - metadata := map[string]interface{}{ - "some": "metadata", - "game": "game", - "platform": "gcm", - } - km := &extensions.KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: token, - Data: map[string]interface{}{ - "title": "notification", - }, - }, - Metadata: metadata, - PushExpiry: extensions.MakeTimestamp() + int64(1000000), - } - - expected := interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: token, - Data: map[string]interface{}{ - "title": "notification", - "some": "metadata", - "game": "game", - "platform": "gcm", - }, - } - - bytes, err := json.Marshal(km) - s.Require().NoError(err) - - s.mockClient.EXPECT(). - SendPush(gomock.Any(), expected). - Return(nil) - - s.handler.HandleMessages(ctx, interfaces.KafkaMessage{Value: bytes}) - s.handler.HandleResponses() - - select { - case m := <-s.mockKafkaProducer.ProduceChannel(): - val := &FeedbackResponse{} - err = json.Unmarshal(m.Value, val) - s.NoError(err) - s.Empty(val.Error) - case <-time.After(time.Second * 1): - s.Fail("did not send feedback to kafka") - } - - s.handler.statsMutex.Lock() - s.Equal(int64(1), s.handler.stats.sent) - s.handler.statsMutex.Unlock() - s.Equal(int64(1), s.mockStatsdClient.Counts["sent"]) - s.Equal(int64(1), s.mockStatsdClient.Counts["ack"]) - }) - - s.Run("should not lock sendPushConcurrencyControl when sending multiple messages", func() { - newMessage := func() extensions.KafkaGCMMessage { - ttl := uint(0) - token := uuid.NewString() - title := fmt.Sprintf("title - %s", uuid.NewString()) - metadata := map[string]interface{}{ - "some": "metadata", - "game": "game", - "platform": "gcm", - } - - km := extensions.KafkaGCMMessage{ - Message: interfaces.Message{ - TimeToLive: &ttl, - DeliveryReceiptRequested: false, - DryRun: true, - To: token, - Data: map[string]interface{}{ - "title": title, - }, - }, - Metadata: metadata, - PushExpiry: extensions.MakeTimestamp() + int64(1000000), - } - - return km - } - - go s.handler.HandleResponses() - - qtyMsgs := 20 - - s.mockClient.EXPECT(). - SendPush(gomock.Any(), gomock.Any()). - Return(nil). - Times(qtyMsgs) - - for i := 0; i < qtyMsgs; i++ { - km := newMessage() - bytes, err := json.Marshal(km) - s.Require().NoError(err) - - go s.handler.HandleMessages(ctx, interfaces.KafkaMessage{Value: bytes}) - } - - for i := 0; i < qtyMsgs; i++ { - select { - case m := <-s.mockKafkaProducer.ProduceChannel(): - val := &FeedbackResponse{} - err := json.Unmarshal(m.Value, val) - s.NoError(err) - s.Empty(val.Error) - case <-time.After(time.Second * 1): - s.Fail("did not send feedback to kafka") - } - } - - time.Sleep(2 * time.Second) - - s.handler.statsMutex.Lock() - s.Equal(int64(20), s.handler.stats.sent) - s.handler.statsMutex.Unlock() - s.Equal(int64(20), s.mockStatsdClient.Counts["sent"]) - s.Equal(int64(20), s.mockStatsdClient.Counts["ack"]) - }) -} diff --git a/extensions/rate_limiter.go b/extensions/rate_limiter.go index a1b170a..8311854 100644 --- a/extensions/rate_limiter.go +++ b/extensions/rate_limiter.go @@ -66,7 +66,7 @@ func (r rateLimiter) Allow(ctx context.Context, device string, game string, plat if err != nil && !errors.Is(err, redis.Nil) { // Something went wrong, return true to avoid blocking notifications. l.WithError(err).Error("could not get current rate in redis") - statsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) + StatsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) return true } if errors.Is(err, redis.Nil) { @@ -78,7 +78,7 @@ func (r rateLimiter) Allow(ctx context.Context, device string, game string, plat if err != nil { // Something went wrong, return true to avoid blocking notifications. l.WithError(err).Error("current rate is invalid") - statsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) + StatsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) return true } @@ -94,7 +94,7 @@ func (r rateLimiter) Allow(ctx context.Context, device string, game string, plat if err != nil { // Allow the operation even if the transaction fails, to avoid blocking notifications. l.WithError(err).Error("increment to current rate failed") - statsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) + StatsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) } l.WithField("currentRate", current).Debug("current rate allows message") diff --git a/feedback/broker.go b/feedback/broker.go index 6ae1fb3..9364386 100644 --- a/feedback/broker.go +++ b/feedback/broker.go @@ -114,7 +114,7 @@ func (b *Broker) Stop() { func (b *Broker) processMessages() { l := b.Logger.WithField( - "method", "processMessages", + "method", "processMessages", ) for { diff --git a/interfaces/message_handler.go b/interfaces/message_handler.go index 1221d6a..5b1659d 100644 --- a/interfaces/message_handler.go +++ b/interfaces/message_handler.go @@ -28,6 +28,4 @@ import "context" type MessageHandler interface { HandleMessages(ctx context.Context, msg KafkaMessage) HandleResponses() - LogStats() - CleanMetadataCache() } diff --git a/mocks/firebase/client.go b/mocks/interfaces/client.go similarity index 94% rename from mocks/firebase/client.go rename to mocks/interfaces/client.go index 4ea5ae2..cb395db 100644 --- a/mocks/firebase/client.go +++ b/mocks/interfaces/client.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -source=interfaces/client.go -destination=mocks/firebase/client.go +// mockgen -source=interfaces/client.go -destination=mocks/interfaces/client.go // // Package mock_interfaces is a generated GoMock package. diff --git a/mocks/interfaces/message_handler.go b/mocks/interfaces/message_handler.go index 5e90a85..0730570 100644 --- a/mocks/interfaces/message_handler.go +++ b/mocks/interfaces/message_handler.go @@ -75,15 +75,3 @@ func (mr *MockMessageHandlerMockRecorder) HandleResponses() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleResponses", reflect.TypeOf((*MockMessageHandler)(nil).HandleResponses)) } - -// LogStats mocks base method. -func (m *MockMessageHandler) LogStats() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LogStats") -} - -// LogStats indicates an expected call of LogStats. -func (mr *MockMessageHandlerMockRecorder) LogStats() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogStats", reflect.TypeOf((*MockMessageHandler)(nil).LogStats)) -} diff --git a/mocks/interfaces/rate_limiter.go b/mocks/interfaces/rate_limiter.go new file mode 100644 index 0000000..d0d2eb4 --- /dev/null +++ b/mocks/interfaces/rate_limiter.go @@ -0,0 +1,54 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces/rate_limiter.go +// +// Generated by this command: +// +// mockgen -source=interfaces/rate_limiter.go -destination=mocks/interfaces/rate_limiter.go +// + +// Package mock_interfaces is a generated GoMock package. +package mock_interfaces + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockRateLimiter is a mock of RateLimiter interface. +type MockRateLimiter struct { + ctrl *gomock.Controller + recorder *MockRateLimiterMockRecorder +} + +// MockRateLimiterMockRecorder is the mock recorder for MockRateLimiter. +type MockRateLimiterMockRecorder struct { + mock *MockRateLimiter +} + +// NewMockRateLimiter creates a new mock instance. +func NewMockRateLimiter(ctrl *gomock.Controller) *MockRateLimiter { + mock := &MockRateLimiter{ctrl: ctrl} + mock.recorder = &MockRateLimiterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRateLimiter) EXPECT() *MockRateLimiterMockRecorder { + return m.recorder +} + +// Allow mocks base method. +func (m *MockRateLimiter) Allow(ctx context.Context, device, game, platform string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Allow", ctx, device, game, platform) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Allow indicates an expected call of Allow. +func (mr *MockRateLimiterMockRecorder) Allow(ctx, device, game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Allow", reflect.TypeOf((*MockRateLimiter)(nil).Allow), ctx, device, game, platform) +} diff --git a/mocks/interfaces/stats_reporter.go b/mocks/interfaces/stats_reporter.go new file mode 100644 index 0000000..d5d4a4b --- /dev/null +++ b/mocks/interfaces/stats_reporter.go @@ -0,0 +1,183 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces/stats_reporter.go +// +// Generated by this command: +// +// mockgen -source=interfaces/stats_reporter.go -destination=mocks/interfaces/stats_reporter.go +// + +// Package mock_interfaces is a generated GoMock package. +package mock_interfaces + +import ( + reflect "reflect" + time "time" + + errors "github.com/topfreegames/pusher/errors" + gomock "go.uber.org/mock/gomock" +) + +// MockStatsReporter is a mock of StatsReporter interface. +type MockStatsReporter struct { + ctrl *gomock.Controller + recorder *MockStatsReporterMockRecorder +} + +// MockStatsReporterMockRecorder is the mock recorder for MockStatsReporter. +type MockStatsReporterMockRecorder struct { + mock *MockStatsReporter +} + +// NewMockStatsReporter creates a new mock instance. +func NewMockStatsReporter(ctrl *gomock.Controller) *MockStatsReporter { + mock := &MockStatsReporter{ctrl: ctrl} + mock.recorder = &MockStatsReporterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStatsReporter) EXPECT() *MockStatsReporterMockRecorder { + return m.recorder +} + +// HandleNotificationFailure mocks base method. +func (m *MockStatsReporter) HandleNotificationFailure(game, platform string, err *errors.PushError) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "HandleNotificationFailure", game, platform, err) +} + +// HandleNotificationFailure indicates an expected call of HandleNotificationFailure. +func (mr *MockStatsReporterMockRecorder) HandleNotificationFailure(game, platform, err any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNotificationFailure", reflect.TypeOf((*MockStatsReporter)(nil).HandleNotificationFailure), game, platform, err) +} + +// HandleNotificationSent mocks base method. +func (m *MockStatsReporter) HandleNotificationSent(game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "HandleNotificationSent", game, platform) +} + +// HandleNotificationSent indicates an expected call of HandleNotificationSent. +func (mr *MockStatsReporterMockRecorder) HandleNotificationSent(game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNotificationSent", reflect.TypeOf((*MockStatsReporter)(nil).HandleNotificationSent), game, platform) +} + +// HandleNotificationSuccess mocks base method. +func (m *MockStatsReporter) HandleNotificationSuccess(game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "HandleNotificationSuccess", game, platform) +} + +// HandleNotificationSuccess indicates an expected call of HandleNotificationSuccess. +func (mr *MockStatsReporterMockRecorder) HandleNotificationSuccess(game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNotificationSuccess", reflect.TypeOf((*MockStatsReporter)(nil).HandleNotificationSuccess), game, platform) +} + +// InitializeFailure mocks base method. +func (m *MockStatsReporter) InitializeFailure(game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "InitializeFailure", game, platform) +} + +// InitializeFailure indicates an expected call of InitializeFailure. +func (mr *MockStatsReporterMockRecorder) InitializeFailure(game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeFailure", reflect.TypeOf((*MockStatsReporter)(nil).InitializeFailure), game, platform) +} + +// NotificationRateLimitFailed mocks base method. +func (m *MockStatsReporter) NotificationRateLimitFailed(game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NotificationRateLimitFailed", game, platform) +} + +// NotificationRateLimitFailed indicates an expected call of NotificationRateLimitFailed. +func (mr *MockStatsReporterMockRecorder) NotificationRateLimitFailed(game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRateLimitFailed", reflect.TypeOf((*MockStatsReporter)(nil).NotificationRateLimitFailed), game, platform) +} + +// NotificationRateLimitReached mocks base method. +func (m *MockStatsReporter) NotificationRateLimitReached(game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NotificationRateLimitReached", game, platform) +} + +// NotificationRateLimitReached indicates an expected call of NotificationRateLimitReached. +func (mr *MockStatsReporterMockRecorder) NotificationRateLimitReached(game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRateLimitReached", reflect.TypeOf((*MockStatsReporter)(nil).NotificationRateLimitReached), game, platform) +} + +// ReportFirebaseLatency mocks base method. +func (m *MockStatsReporter) ReportFirebaseLatency(latencyMs time.Duration, game string, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{latencyMs, game} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "ReportFirebaseLatency", varargs...) +} + +// ReportFirebaseLatency indicates an expected call of ReportFirebaseLatency. +func (mr *MockStatsReporterMockRecorder) ReportFirebaseLatency(latencyMs, game any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{latencyMs, game}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportFirebaseLatency", reflect.TypeOf((*MockStatsReporter)(nil).ReportFirebaseLatency), varargs...) +} + +// ReportGoStats mocks base method. +func (m *MockStatsReporter) ReportGoStats(numGoRoutines int, allocatedAndNotFreed, heapObjects, nextGCBytes, pauseGCNano uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReportGoStats", numGoRoutines, allocatedAndNotFreed, heapObjects, nextGCBytes, pauseGCNano) +} + +// ReportGoStats indicates an expected call of ReportGoStats. +func (mr *MockStatsReporterMockRecorder) ReportGoStats(numGoRoutines, allocatedAndNotFreed, heapObjects, nextGCBytes, pauseGCNano any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportGoStats", reflect.TypeOf((*MockStatsReporter)(nil).ReportGoStats), numGoRoutines, allocatedAndNotFreed, heapObjects, nextGCBytes, pauseGCNano) +} + +// ReportMetricCount mocks base method. +func (m *MockStatsReporter) ReportMetricCount(metric string, value int64, game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReportMetricCount", metric, value, game, platform) +} + +// ReportMetricCount indicates an expected call of ReportMetricCount. +func (mr *MockStatsReporterMockRecorder) ReportMetricCount(metric, value, game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportMetricCount", reflect.TypeOf((*MockStatsReporter)(nil).ReportMetricCount), metric, value, game, platform) +} + +// ReportMetricGauge mocks base method. +func (m *MockStatsReporter) ReportMetricGauge(metric string, value float64, game, platform string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReportMetricGauge", metric, value, game, platform) +} + +// ReportMetricGauge indicates an expected call of ReportMetricGauge. +func (mr *MockStatsReporterMockRecorder) ReportMetricGauge(metric, value, game, platform any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportMetricGauge", reflect.TypeOf((*MockStatsReporter)(nil).ReportMetricGauge), metric, value, game, platform) +} + +// ReportSendNotificationLatency mocks base method. +func (m *MockStatsReporter) ReportSendNotificationLatency(latencyMs time.Duration, game, platform string, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{latencyMs, game, platform} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "ReportSendNotificationLatency", varargs...) +} + +// ReportSendNotificationLatency indicates an expected call of ReportSendNotificationLatency. +func (mr *MockStatsReporterMockRecorder) ReportSendNotificationLatency(latencyMs, game, platform any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{latencyMs, game, platform}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportSendNotificationLatency", reflect.TypeOf((*MockStatsReporter)(nil).ReportSendNotificationLatency), varargs...) +} diff --git a/mocks/rate_limiter.go b/mocks/rate_limiter.go deleted file mode 100644 index a05a6a6..0000000 --- a/mocks/rate_limiter.go +++ /dev/null @@ -1,14 +0,0 @@ -package mocks - -import "context" - -type rateLimiterMock struct { -} - -func NewRateLimiterMock() *rateLimiterMock { - return &rateLimiterMock{} -} - -func (rl *rateLimiterMock) Allow(ctx context.Context, device, game, platform string) bool { - return true -} diff --git a/pusher/apns.go b/pusher/apns.go index 6d304af..40829b5 100644 --- a/pusher/apns.go +++ b/pusher/apns.go @@ -30,6 +30,7 @@ import ( "github.com/spf13/viper" "github.com/topfreegames/pusher/config" "github.com/topfreegames/pusher/extensions" + "github.com/topfreegames/pusher/extensions/apns" "github.com/topfreegames/pusher/interfaces" ) @@ -107,7 +108,7 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, "topic": topic, }).Info("configuring apns message handler") - handler, err := extensions.NewAPNSMessageHandler( + handler, err := apns.NewAPNSMessageHandler( authKeyPath, keyID, teamID, @@ -120,7 +121,6 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, a.StatsReporters, a.feedbackReporters, queue, - interfaces.ConsumptionManager(q), extensions.NewRateLimiter(rateLimit, a.ViperConfig, a.StatsReporters, l.Logger), ) if err == nil { @@ -129,11 +129,11 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, for _, statsReporter := range a.StatsReporters { statsReporter.InitializeFailure(k, "apns") } - return fmt.Errorf("failed to initialize apns handler for %s", k) + return fmt.Errorf("failed to initialize apns firebase for %s", k) } } if len(a.MessageHandler) == 0 { - return errors.New("could not initilize any app") + return errors.New("could not initialize any app") } return nil } diff --git a/pusher/gcm.go b/pusher/gcm.go index 4e55ee5..1c51514 100644 --- a/pusher/gcm.go +++ b/pusher/gcm.go @@ -25,13 +25,13 @@ package pusher import ( "context" "fmt" + "github.com/topfreegames/pusher/extensions/firebase/client" "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/topfreegames/pusher/config" "github.com/topfreegames/pusher/extensions" - "github.com/topfreegames/pusher/extensions/client" - "github.com/topfreegames/pusher/extensions/handler" + "github.com/topfreegames/pusher/extensions/firebase" "github.com/topfreegames/pusher/interfaces" ) @@ -101,41 +101,25 @@ func (g *GCMPusher) createMessageHandlerForApps(ctx context.Context) error { rateLimit := g.ViperConfig.GetInt("gcm.rateLimit.rpm") l = l.WithField("app", app) - if credentials != "" { // Firebase is configured, use new handler - pushClient, err := client.NewFirebaseClient(ctx, credentials, g.Logger) - if err != nil { - l.WithError(err).Error("could not create firebase client") - return fmt.Errorf("could not create firebase pushClient for all apps: %w", err) - } - l.Debug("created new message handler with firebase client") - g.MessageHandler[app] = handler.NewMessageHandler( - app, - pushClient, - g.feedbackReporters, - g.StatsReporters, - extensions.NewRateLimiter(rateLimit, g.ViperConfig, g.StatsReporters, l.Logger), - g.Logger, - g.Config.GCM.ConcurrentWorkers, - ) - } else { // Firebase credentials not yet configured, use legacy XMPP client - handler, err := extensions.NewGCMMessageHandler( - app, - g.IsProduction, - g.ViperConfig, - g.Logger, - g.Queue.PendingMessagesWaitGroup(), - g.StatsReporters, - g.feedbackReporters, - extensions.NewRateLimiter(rateLimit, g.ViperConfig, g.StatsReporters, l.Logger), - ) - if err != nil { - l.WithError(err).Error("could not create gcm message handler") - return fmt.Errorf("could not create gcm message handler for all apps: %w", err) - } - - l.Debug("created legacy message handler with xmpp client") - g.MessageHandler[app] = handler + if credentials == "" { + l.Fatalf("firebase credentials not found for %s", app) + } + pushClient, err := client.NewFirebaseClient(ctx, credentials, g.Logger) + if err != nil { + l.WithError(err).Error("could not create firebase client") + return fmt.Errorf("could not create firebase pushClient for all apps: %w", err) } + l.Debug("created new message handler with firebase client") + g.MessageHandler[app] = firebase.NewMessageHandler( + app, + pushClient, + g.feedbackReporters, + g.StatsReporters, + extensions.NewRateLimiter(rateLimit, g.ViperConfig, g.StatsReporters, l.Logger), + g.Queue.PendingMessagesWaitGroup(), + g.Logger, + g.Config.GCM.ConcurrentWorkers, + ) } return nil } diff --git a/pusher/pusher.go b/pusher/pusher.go index bae3726..7ae46fa 100644 --- a/pusher/pusher.go +++ b/pusher/pusher.go @@ -113,8 +113,6 @@ func (p *Pusher) Start(ctx context.Context) { go p.routeMessages(p.Queue.MessagesChannel()) for _, v := range p.MessageHandler { go v.HandleResponses() - go v.LogStats() - go v.CleanMetadataCache() } //nolint[:errcheck] go p.Queue.ConsumeLoop(ctx)