From 201f0e1a75d12e6c9c605f96db026bb82a8c516d Mon Sep 17 00:00:00 2001 From: Gustavo <25396922+gussf@users.noreply.github.com> Date: Wed, 19 Jun 2024 10:46:50 -0300 Subject: [PATCH] feat: add rate limiter to apns and gcm (#58) * feat: add rate limiter to apns and gcm * fix: add rate limiter mock and fix tests * try to fix tests * add redis to github worflow * fix: rate limiter config * chore: add rate limiter to firebase handler * fix: add tls config * chore: log level and config name * fix: tests * chore: try to fix pipeline * try to enable tls in pipeline * chore: add test flag to config to set tls * chore: fix port type * chore: add pr suggestions * chore: add labels to redis failure metric --- .github/workflows/integration-tests.yaml | 8 ++ config/default.yaml | 9 ++ config/docker_test.yaml | 10 +- config/test.yaml | 10 +- docker-compose-container-dev.yml | 11 +++ docker-compose.yml | 10 ++ e2e/fcm_e2e_test.go | 4 + extensions/apns_message_handler.go | 13 ++- extensions/apns_message_handler_test.go | 9 +- extensions/common.go | 12 +++ extensions/datadog_statsd.go | 18 ++++ extensions/gcm_message_handler.go | 13 ++- extensions/gcm_message_handler_test.go | 10 +- extensions/handler/message_handler.go | 16 ++++ extensions/handler/message_handler_test.go | 2 + extensions/rate_limiter.go | 102 +++++++++++++++++++++ extensions/rate_limiter_test.go | 76 +++++++++++++++ go.mod | 7 +- go.sum | 8 ++ interfaces/rate_limiter.go | 30 ++++++ interfaces/stats_reporter.go | 5 +- mocks/rate_limiter.go | 14 +++ pusher/apns.go | 3 + pusher/gcm.go | 4 + 24 files changed, 392 insertions(+), 12 deletions(-) create mode 100644 extensions/rate_limiter.go create mode 100644 extensions/rate_limiter_test.go create mode 100644 interfaces/rate_limiter.go create mode 100644 mocks/rate_limiter.go diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 2c754b8..69c41f1 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -49,6 +49,14 @@ jobs: --health-retries 5 statsd: image: hopsoft/graphite-statsd + redis: + image: redis:6.0.9-alpine + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: - uses: actions/checkout@v2 - name: Set up go vendor cache diff --git a/config/default.yaml b/config/default.yaml index e0324a4..d9eb1b3 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -1,6 +1,7 @@ --- gracefulShutdownTimeout: 30 apns: + rateLimit.rpm: 20 concurrentWorkers: 300 connectionPoolSize: 1 pushQueueSize: 100 @@ -14,6 +15,7 @@ apns: teamID: "ABC123DEFG" topic: "com.game.test" gcm: + rateLimit.rpm: 20 pingInterval: 30 pingTimeout: 10 maxPendingMessages: 100 @@ -93,3 +95,10 @@ feedbackListeners: maxRetries: 3 database: push connectionTimeout: 100 +rateLimiter: + redis: + host: "localhost" + port: 6379 + password: "" + tls: + disabled: false \ No newline at end of file diff --git a/config/docker_test.yaml b/config/docker_test.yaml index 7ab3cac..fe22f2f 100644 --- a/config/docker_test.yaml +++ b/config/docker_test.yaml @@ -1,6 +1,7 @@ --- gracefulShutdownTimeout: 10 apns: + rateLimit.rpm: 100 concurrentWorkers: 300 connectionPoolSize: 1 logStatsInterval: 750 @@ -14,6 +15,7 @@ apns: responsechannelsize: 100 connectionpoolsize: 10 gcm: + rateLimit.rpm: 100 pingInterval: 30 pingTimeout: 10 maxPendingMessages: 3 @@ -93,4 +95,10 @@ feedbackListeners: maxRetries: 3 database: push connectionTimeout: 100 - +rateLimiter: + redis: + host: "redis" + port: 6379 + password: "" + tls: + disabled: true \ No newline at end of file diff --git a/config/test.yaml b/config/test.yaml index 7c8e9aa..3086589 100644 --- a/config/test.yaml +++ b/config/test.yaml @@ -1,6 +1,7 @@ --- gracefulShutdownTimeout: 10 apns: + rateLimit.rpm: 100 concurrentWorkers: 300 connectionPoolSize: 1 logStatsInterval: 750 @@ -12,6 +13,7 @@ apns: teamID: "ABC123DEFG" topic: "com.game.test" gcm: + rateLimit.rpm: 100 pingInterval: 30 pingTimeout: 10 maxPendingMessages: 3 @@ -91,4 +93,10 @@ feedbackListeners: maxRetries: 3 database: push connectionTimeout: 100 - +rateLimiter: + redis: + host: "localhost" + port: 6379 + password: "" + tls: + disabled: true \ No newline at end of file diff --git a/docker-compose-container-dev.yml b/docker-compose-container-dev.yml index a7d7a55..f65198e 100644 --- a/docker-compose-container-dev.yml +++ b/docker-compose-container-dev.yml @@ -90,3 +90,14 @@ services: KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092 KAFKA_CLUSTERS_0_METRICS_PORT: 9997 DYNAMIC_CONFIG_ENABLED: 'true' + + redis: + image: redis:6.0.9-alpine + container_name: redis + ports: + - "6379:6379" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 3 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index f695b48..e38a5d0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -96,3 +96,13 @@ services: condition: service_healthy environment: KAFKA_BROKERS: kafka:9092 + + redis: + image: redis:6.0.9-alpine + ports: + - "6379:6379" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 3 \ No newline at end of file diff --git a/e2e/fcm_e2e_test.go b/e2e/fcm_e2e_test.go index db5dbb0..c7cf7bf 100644 --- a/e2e/fcm_e2e_test.go +++ b/e2e/fcm_e2e_test.go @@ -71,6 +71,9 @@ func (s *FcmE2ETestSuite) setupFcmPusher(appName string) (*firebaseMock.MockPush statsReport, err := extensions.NewStatsD(s.vConfig, logger, statsdClientMock) s.Require().NoError(err) + limit := s.vConfig.GetInt("gcm.rateLimit.rpm") + rateLimiter := extensions.NewRateLimiter(limit, s.vConfig, []interfaces.StatsReporter{statsReport}, logger) + pushClient := firebaseMock.NewMockPushClient(ctrl) gcmPusher.MessageHandler = map[string]interfaces.MessageHandler{ appName: handler.NewMessageHandler( @@ -78,6 +81,7 @@ func (s *FcmE2ETestSuite) setupFcmPusher(appName string) (*firebaseMock.MockPush pushClient, []interfaces.FeedbackReporter{}, []interfaces.StatsReporter{statsReport}, + rateLimiter, logger, s.config.GCM.ConcurrentWorkers, ), diff --git a/extensions/apns_message_handler.go b/extensions/apns_message_handler.go index bb993c6..9890cda 100644 --- a/extensions/apns_message_handler.go +++ b/extensions/apns_message_handler.go @@ -79,6 +79,7 @@ type APNSMessageHandler struct { consumptionManager interfaces.ConsumptionManager retryInterval time.Duration maxRetryAttempts uint + rateLimiter interfaces.RateLimiter } var _ interfaces.MessageHandler = &APNSMessageHandler{} @@ -94,6 +95,7 @@ func NewAPNSMessageHandler( feedbackReporters []interfaces.FeedbackReporter, pushQueue interfaces.APNSPushQueue, consumptionManager interfaces.ConsumptionManager, + rateLimiter interfaces.RateLimiter, ) (*APNSMessageHandler, error) { a := &APNSMessageHandler{ authKeyPath: authKeyPath, @@ -117,6 +119,7 @@ func NewAPNSMessageHandler( requestsHeap: NewTimeoutHeap(config), PushQueue: pushQueue, consumptionManager: consumptionManager, + rateLimiter: rateLimiter, } if a.Logger != nil { @@ -216,7 +219,7 @@ func (a *APNSMessageHandler) CleanMetadataCache() { } // HandleMessages get messages from msgChan and send to APNS. -func (a *APNSMessageHandler) HandleMessages(_ context.Context, message interfaces.KafkaMessage) { +func (a *APNSMessageHandler) HandleMessages(ctx context.Context, message interfaces.KafkaMessage) { l := a.Logger.WithFields(log.Fields{ "method": "HandleMessages", "jsonValue": string(message.Value), @@ -227,6 +230,14 @@ func (a *APNSMessageHandler) HandleMessages(_ context.Context, message interface if err != nil { return } + + allowed := a.rateLimiter.Allow(ctx, notification.DeviceToken, a.appName, "apns") + if !allowed { + statsReporterNotificationRateLimitReached(a.StatsReporters, a.appName, "apns") + l.WithField("message", message).Warn("rate limit reached") + return + } + if err := a.sendNotification(notification); err != nil { return } diff --git a/extensions/apns_message_handler_test.go b/extensions/apns_message_handler_test.go index 571b29b..ac8d01c 100644 --- a/extensions/apns_message_handler_test.go +++ b/extensions/apns_message_handler_test.go @@ -26,11 +26,12 @@ import ( "context" "encoding/json" "fmt" + "os" + "time" + uuid "github.com/satori/go.uuid" "github.com/sideshow/apns2" mock_interfaces "github.com/topfreegames/pusher/mocks/interfaces" - "os" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -51,6 +52,7 @@ var _ = FDescribe("APNS Message Handler", func() { var mockStatsDClient *mocks.StatsDClientMock var statsClients []interfaces.StatsReporter mockConsumptionManager := mock_interfaces.NewMockConsumptionManager() + mockRateLimiter := mocks.NewRateLimiterMock() ctx := context.Background() configFile := os.Getenv("CONFIG_FILE") @@ -98,6 +100,7 @@ var _ = FDescribe("APNS Message Handler", func() { feedbackClients, mockPushQueue, mockConsumptionManager, + mockRateLimiter, ) Expect(err).NotTo(HaveOccurred()) db.(*mocks.PGMock).RowsReturned = 0 @@ -668,6 +671,7 @@ var _ = FDescribe("APNS Message Handler", func() { feedbackClients, mockPushQueue, mockConsumptionManager, + mockRateLimiter, ) Expect(err).NotTo(HaveOccurred()) }) @@ -880,6 +884,7 @@ var _ = FDescribe("APNS Message Handler", func() { nil, nil, nil, + mockRateLimiter, ) Expect(err).NotTo(HaveOccurred()) hook.Reset() diff --git a/extensions/common.go b/extensions/common.go index 8f7f16b..b933098 100644 --- a/extensions/common.go +++ b/extensions/common.go @@ -88,6 +88,18 @@ func statsReporterHandleNotificationFailure( } } +func statsReporterNotificationRateLimitReached(statsReporters []interfaces.StatsReporter, game string, platform string) { + for _, statsReporter := range statsReporters { + statsReporter.NotificationRateLimitReached(game, platform) + } +} + +func statsReporterNotificationRateLimitFailed(statsReporters []interfaces.StatsReporter, game string, platform string) { + for _, statsReporter := range statsReporters { + statsReporter.NotificationRateLimitFailed(game, platform) + } +} + func statsReporterReportSendNotificationLatency(statsReporters []interfaces.StatsReporter, latencyMs time.Duration, game string, platform string, labels ...string) { for _, statsReporter := range statsReporters { statsReporter.ReportSendNotificationLatency(latencyMs, game, platform, labels...) diff --git a/extensions/datadog_statsd.go b/extensions/datadog_statsd.go index fdf1a3f..fa4a6b2 100644 --- a/extensions/datadog_statsd.go +++ b/extensions/datadog_statsd.go @@ -111,6 +111,24 @@ func (s *StatsD) HandleNotificationFailure(game string, platform string, err *er ) } +// NotificationRateLimitReached stores how many times rate limits were reached for the devices +func (s *StatsD) NotificationRateLimitReached(game string, platform string) { + s.Client.Incr( + "rate_limit_reached", + []string{fmt.Sprintf("platform:%s", platform), fmt.Sprintf("game:%s", game)}, + 1, + ) +} + +// NotificationRateLimitFailed stores how many times rate limits failed to be calculated +func (s *StatsD) NotificationRateLimitFailed(game string, platform string) { + s.Client.Incr( + "rate_limit_failed", + []string{fmt.Sprintf("platform:%s", platform), fmt.Sprintf("game:%s", game)}, + 1, + ) +} + // InitializeFailure notifu error when is impossible tho initilizer an app func (s *StatsD) InitializeFailure(game string, platform string) { s.Client.Incr("initialize_failure", []string{fmt.Sprintf("platform:%s", platform), fmt.Sprintf("game:%s", game)}, 1) diff --git a/extensions/gcm_message_handler.go b/extensions/gcm_message_handler.go index 330cb06..822a73c 100644 --- a/extensions/gcm_message_handler.go +++ b/extensions/gcm_message_handler.go @@ -76,6 +76,7 @@ type GCMMessageHandler struct { requestsHeap *TimeoutHeap CacheCleaningInterval int IsProduction bool + rateLimiter interfaces.RateLimiter } // NewGCMMessageHandler returns a new instance of a GCMMessageHandler @@ -87,6 +88,7 @@ func NewGCMMessageHandler( pendingMessagesWG *sync.WaitGroup, statsReporters []interfaces.StatsReporter, feedbackReporters []interfaces.FeedbackReporter, + rateLimiter interfaces.RateLimiter, ) (*GCMMessageHandler, error) { l := logger.WithFields(logrus.Fields{ "method": "NewGCMMessageHandler", @@ -94,7 +96,7 @@ func NewGCMMessageHandler( "isProduction": isProduction, }) - h, err := NewGCMMessageHandlerWithClient(game, isProduction, config, l.Logger, pendingMessagesWG, statsReporters, feedbackReporters, nil) + h, err := NewGCMMessageHandlerWithClient(game, isProduction, config, l.Logger, pendingMessagesWG, statsReporters, feedbackReporters, nil, rateLimiter) if err != nil { l.WithError(err).Error("Failed to create a new GCM Message handler.") return nil, err @@ -111,6 +113,7 @@ func NewGCMMessageHandlerWithClient( statsReporters []interfaces.StatsReporter, feedbackReporters []interfaces.FeedbackReporter, client interfaces.GCMClient, + rateLimiter interfaces.RateLimiter, ) (*GCMMessageHandler, error) { l := logger.WithFields(logrus.Fields{ "method": "NewGCMMessageHandlerWithClient", @@ -131,6 +134,7 @@ func NewGCMMessageHandlerWithClient( requestsHeap: NewTimeoutHeap(config), StatsReporters: statsReporters, GCMClient: client, + rateLimiter: rateLimiter, } err := g.configure() @@ -340,6 +344,13 @@ func (g *GCMMessageHandler) sendMessage(message interfaces.KafkaMessage) error { } l = l.WithField("message", km) + + allowed := g.rateLimiter.Allow(context.Background(), km.To, message.Game, "gcm") + if !allowed { + statsReporterNotificationRateLimitReached(g.StatsReporters, message.Game, "gcm") + l.WithField("message", message).Warn("rate limit reached") + return errors.New("rate limit reached") + } l.Debug("sending message to gcm") var messageID string diff --git a/extensions/gcm_message_handler_test.go b/extensions/gcm_message_handler_test.go index 0b3dada..3e14321 100644 --- a/extensions/gcm_message_handler_test.go +++ b/extensions/gcm_message_handler_test.go @@ -24,13 +24,14 @@ package extensions import ( "encoding/json" + "os" + "testing" + "time" + "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/suite" "github.com/topfreegames/pusher/config" - "os" - "testing" - "time" uuid "github.com/satori/go.uuid" "github.com/sirupsen/logrus/hooks/test" @@ -73,6 +74,7 @@ func (s *GCMMessageHandlerTestSuite) setupHandler() ( logger, _ := test.NewNullLogger() mockClient := mocks.NewGCMClientMock() mockStatsdClient := mocks.NewStatsDClientMock() + mockRateLimiter := mocks.NewRateLimiterMock() statsD, err := NewStatsD(s.vConfig, logger, mockStatsdClient) s.Require().NoError(err) @@ -92,6 +94,7 @@ func (s *GCMMessageHandlerTestSuite) setupHandler() ( statsClients, feedbackClients, mockClient, + mockRateLimiter, ) s.NoError(err) s.Require().NotNil(handler) @@ -115,6 +118,7 @@ func (s *GCMMessageHandlerTestSuite) TestConfigureHandler() { nil, []interfaces.StatsReporter{}, []interfaces.FeedbackReporter{}, + nil, ) s.Error(err) s.Nil(handler) diff --git a/extensions/handler/message_handler.go b/extensions/handler/message_handler.go index ca9edd2..71afdd2 100644 --- a/extensions/handler/message_handler.go +++ b/extensions/handler/message_handler.go @@ -21,6 +21,7 @@ type messageHandler struct { statsMutex sync.Mutex feedbackReporters []interfaces.FeedbackReporter statsReporters []interfaces.StatsReporter + rateLimiter interfaces.RateLimiter statsDClient extensions.StatsD sendPushConcurrencyControl chan interface{} responsesChannel chan struct { @@ -36,6 +37,7 @@ func NewMessageHandler( client interfaces.PushClient, feedbackReporters []interfaces.FeedbackReporter, statsReporters []interfaces.StatsReporter, + rateLimiter interfaces.RateLimiter, logger *logrus.Logger, concurrentWorkers int, ) interfaces.MessageHandler { @@ -51,6 +53,7 @@ func NewMessageHandler( client: client, feedbackReporters: feedbackReporters, statsReporters: statsReporters, + rateLimiter: rateLimiter, logger: l.Logger, config: cfg, sendPushConcurrencyControl: make(chan interface{}, concurrentWorkers), @@ -88,6 +91,13 @@ func (h *messageHandler) HandleMessages(ctx context.Context, msg interfaces.Kafk return } + allowed := h.rateLimiter.Allow(ctx, km.To, msg.Game, "gcm") + if !allowed { + h.reportRateLimitReached(msg.Game) + l.WithField("message", msg).Warn("rate limit reached") + return + } + if km.Metadata != nil { if km.Message.Data == nil { km.Message.Data = map[string]interface{}{} @@ -237,6 +247,12 @@ func (h *messageHandler) reportFirebaseLatency(latency time.Duration) { } } +func (h *messageHandler) reportRateLimitReached(game string) { + for _, statsReporter := range h.statsReporters { + statsReporter.NotificationRateLimitReached(game, "gcm") + } +} + func translateToPushError(err error) *pushErrors.PushError { if pusherError, ok := err.(*pushErrors.PushError); ok { return pusherError diff --git a/extensions/handler/message_handler_test.go b/extensions/handler/message_handler_test.go index f0a49bd..c7c328e 100644 --- a/extensions/handler/message_handler_test.go +++ b/extensions/handler/message_handler_test.go @@ -69,6 +69,7 @@ func (s *MessageHandlerTestSuite) SetupSubTest() { statsClients := []interfaces.StatsReporter{statsD} feedbackClients := []interfaces.FeedbackReporter{kc} + mockRateLimiter := mocks.NewRateLimiterMock() cfg := newDefaultMessageHandlerConfig() cfg.concurrentResponseHandlers = concurrentWorkers @@ -77,6 +78,7 @@ func (s *MessageHandlerTestSuite) SetupSubTest() { client: s.mockClient, feedbackReporters: feedbackClients, statsReporters: statsClients, + rateLimiter: mockRateLimiter, logger: l, config: cfg, sendPushConcurrencyControl: make(chan interface{}, concurrentWorkers), diff --git a/extensions/rate_limiter.go b/extensions/rate_limiter.go new file mode 100644 index 0000000..a1b170a --- /dev/null +++ b/extensions/rate_limiter.go @@ -0,0 +1,102 @@ +package extensions + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "strconv" + "time" + + "github.com/redis/go-redis/v9" + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/topfreegames/pusher/interfaces" +) + +type rateLimiter struct { + redis *redis.Client + rpmLimit int + statsReporters []interfaces.StatsReporter + l *logrus.Entry +} + +func NewRateLimiter(limit int, config *viper.Viper, statsReporters []interfaces.StatsReporter, logger *logrus.Logger) rateLimiter { + host := config.GetString("rateLimiter.redis.host") + port := config.GetInt("rateLimiter.redis.port") + pwd := config.GetString("rateLimiter.redis.password") + disableTLS := config.GetBool("rateLimiter.tls.disabled") + + addr := fmt.Sprintf("%s:%d", host, port) + opts := &redis.Options{ + Addr: addr, + Password: pwd, + } + + // TLS for integration tests running in containers can raise connection errors. + // Not recommended to disable TLS for production. + if !disableTLS { + opts.TLSConfig = &tls.Config{} + } + + rdb := redis.NewClient(opts) + + return rateLimiter{ + redis: rdb, + rpmLimit: limit, + statsReporters: statsReporters, + l: logger.WithFields(logrus.Fields{ + "extension": "RateLimiter", + "rpmLimit": limit, + }), + } +} + +// Allow checks Redis for the current rate a given device has in the current minute +// If the rate is lower than the limit, the message is allowed. Otherwise, it is not allowed. +// Reference: https://redis.io/glossary/rate-limiting/ +func (r rateLimiter) Allow(ctx context.Context, device string, game string, platform string) bool { + deviceKey := fmt.Sprintf("%s:%d", device, time.Now().Minute()) + l := r.l.WithFields(logrus.Fields{ + "deviceKey": deviceKey, + "method": "Allow", + }) + + val, err := r.redis.Get(ctx, deviceKey).Result() + if err != nil && !errors.Is(err, redis.Nil) { + // Something went wrong, return true to avoid blocking notifications. + l.WithError(err).Error("could not get current rate in redis") + statsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) + return true + } + if errors.Is(err, redis.Nil) { + // First time + val = "0" + } + + current, err := strconv.Atoi(val) + if err != nil { + // Something went wrong, return true to avoid blocking notifications. + l.WithError(err).Error("current rate is invalid") + statsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) + return true + } + + if current >= r.rpmLimit { + return false + } + + _, err = r.redis.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Incr(ctx, deviceKey) + pipe.Expire(ctx, deviceKey, time.Minute) + return nil + }) + if err != nil { + // Allow the operation even if the transaction fails, to avoid blocking notifications. + l.WithError(err).Error("increment to current rate failed") + statsReporterNotificationRateLimitFailed(r.statsReporters, game, platform) + } + + l.WithField("currentRate", current).Debug("current rate allows message") + return true +} diff --git a/extensions/rate_limiter_test.go b/extensions/rate_limiter_test.go new file mode 100644 index 0000000..f7e821e --- /dev/null +++ b/extensions/rate_limiter_test.go @@ -0,0 +1,76 @@ +package extensions + +import ( + "context" + "fmt" + "os" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/topfreegames/pusher/interfaces" + "github.com/topfreegames/pusher/util" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" +) + +var _ = FDescribe("Rate Limiter", func() { + Describe("[Integration]", func() { + logger, hook := test.NewNullLogger() + logger.Level = logrus.DebugLevel + configFile := os.Getenv("CONFIG_FILE") + if configFile == "" { + configFile = "../config/test.yaml" + } + config, err := util.NewViperWithConfigFile(configFile) + Expect(err).NotTo(HaveOccurred()) + hook.Reset() + game := "test" + platform := "test" + statsClients := []interfaces.StatsReporter{} + + Describe("Rate limiting", func() { + It("should return not-allowed when rate limit is reached", func() { + rl := NewRateLimiter(1, config, statsClients, logger) + ctx := context.Background() + device := uuid.NewString() + allowed := rl.Allow(ctx, device, game, platform) + Expect(allowed).To(BeTrue()) + + // Should not allow due to reaching limit of 1 + allowed = rl.Allow(ctx, device, game, platform) + Expect(allowed).To(BeFalse()) + }) + + It("should increment current rate if limit is not reached", func() { + rl := NewRateLimiter(10, config, statsClients, logger) + ctx := context.Background() + device := uuid.NewString() + currMin := time.Now().Minute() + + allowed := rl.Allow(ctx, device, game, platform) + Expect(allowed).To(BeTrue()) + + key := fmt.Sprintf("%s:%d", device, currMin) + actual, err := rl.redis.Get(ctx, key).Result() + Expect(err).ToNot(HaveOccurred()) + Expect(actual).To(BeEquivalentTo("1")) + }) + + It("should return allowed if redis fails", func() { + wrongConfig, err := util.NewViperWithConfigFile(configFile) + Expect(err).NotTo(HaveOccurred()) + wrongConfig.Set("rateLimiter.redis.host", "unreachable") + rl := NewRateLimiter(10, wrongConfig, statsClients, logger) + ctx := context.Background() + device := uuid.NewString() + + allowed := rl.Allow(ctx, device, game, platform) + Expect(allowed).To(BeTrue()) + }) + + }) + }) +}) diff --git a/go.mod b/go.mod index 0b172b7..8b8e280 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,11 @@ require ( github.com/DataDog/datadog-go v0.0.0-20170427165718-0ddda6bee211 github.com/confluentinc/confluent-kafka-go/v2 v2.2.0 github.com/getsentry/raven-go v0.2.0 + github.com/google/uuid v1.3.0 + github.com/mitchellh/mapstructure v1.1.2 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.18.1 + github.com/redis/go-redis/v9 v9.5.1 github.com/satori/go.uuid v1.2.0 github.com/sideshow/apns2 v0.0.0-20170926093756-a3ce9c6f95f6 github.com/sirupsen/logrus v1.8.1 @@ -31,14 +34,15 @@ require ( cloud.google.com/go/storage v1.30.1 // indirect github.com/MicahParks/keyfunc v1.9.0 // indirect github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.8.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -47,7 +51,6 @@ require ( github.com/jpillora/backoff v1.0.0 // indirect github.com/magiconair/properties v1.8.6 // indirect github.com/mattn/go-xmpp v0.0.0-20170423100754-906d9d747d2b // indirect - github.com/mitchellh/mapstructure v1.1.2 // indirect github.com/nxadm/tail v1.4.11 // indirect github.com/pborman/uuid v0.0.0-20170612153648-e790cca94e6c // indirect github.com/pelletier/go-toml v1.9.3 // indirect diff --git a/go.sum b/go.sum index 19966b1..6b100e5 100644 --- a/go.sum +++ b/go.sum @@ -693,6 +693,10 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= @@ -711,6 +715,7 @@ github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6 github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/checkpoint-restore/go-criu/v4 v4.1.0/go.mod h1:xUQBLp4RLc5zJtWY++yjOoMoB5lihDt7fai+75m+rGw= github.com/checkpoint-restore/go-criu/v5 v5.0.0/go.mod h1:cfwC0EG7HMUenopBsUf9d89JlCLQIfgVcNsNN0t6T2M= @@ -884,6 +889,7 @@ github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba/go.mod h1:dV8l github.com/dgrijalva/jwt-go v0.0.0-20170104182250-a601269ab70c/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dnaeon/go-vcr v1.0.1/go.mod h1:aBB1+wY4s93YsC3HHjMBMrwTj2R9FHDzUr9KyGc8n1E= @@ -1461,6 +1467,8 @@ github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/redis/go-redis/v9 v9.5.1 h1:H1X4D3yHPaYrkL5X06Wh6xNVM/pX0Ft4RV0vMGvLBh8= +github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/clock v0.0.0-20190514195947-2896927a307a/go.mod h1:4r5QyqhjIWCcK8DO4KMclc5Iknq5qVBAlbYYzAbUScQ= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= diff --git a/interfaces/rate_limiter.go b/interfaces/rate_limiter.go new file mode 100644 index 0000000..b605886 --- /dev/null +++ b/interfaces/rate_limiter.go @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2016 TFG Co + * Author: TFG Co + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package interfaces + +import "context" + +// RateLimiter interface for rate limiting notifications per device. +type RateLimiter interface { + Allow(ctx context.Context, device, game, platform string) bool +} diff --git a/interfaces/stats_reporter.go b/interfaces/stats_reporter.go index 79b8e29..182bf24 100644 --- a/interfaces/stats_reporter.go +++ b/interfaces/stats_reporter.go @@ -23,8 +23,9 @@ package interfaces import ( - "github.com/topfreegames/pusher/errors" "time" + + "github.com/topfreegames/pusher/errors" ) // StatsReporter interface for making stats reporters pluggable easily. @@ -36,6 +37,8 @@ type StatsReporter interface { ReportGoStats(numGoRoutines int, allocatedAndNotFreed, heapObjects, nextGCBytes, pauseGCNano uint64) ReportMetricGauge(metric string, value float64, game string, platform string) ReportMetricCount(metric string, value int64, game string, platform string) + NotificationRateLimitReached(game string, platform string) + NotificationRateLimitFailed(game string, platform string) ReportSendNotificationLatency(latencyMs time.Duration, game string, platform string, labels ...string) ReportFirebaseLatency(latencyMs time.Duration, game string, labels ...string) } diff --git a/mocks/rate_limiter.go b/mocks/rate_limiter.go new file mode 100644 index 0000000..a05a6a6 --- /dev/null +++ b/mocks/rate_limiter.go @@ -0,0 +1,14 @@ +package mocks + +import "context" + +type rateLimiterMock struct { +} + +func NewRateLimiterMock() *rateLimiterMock { + return &rateLimiterMock{} +} + +func (rl *rateLimiterMock) Allow(ctx context.Context, device, game, platform string) bool { + return true +} diff --git a/pusher/apns.go b/pusher/apns.go index ccdada2..f1d9e88 100644 --- a/pusher/apns.go +++ b/pusher/apns.go @@ -25,6 +25,7 @@ package pusher import ( "errors" "fmt" + "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/topfreegames/pusher/config" @@ -97,6 +98,7 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, keyID := a.ViperConfig.GetString("apns.certs." + k + ".keyID") teamID := a.ViperConfig.GetString("apns.certs." + k + ".teamID") topic := a.ViperConfig.GetString("apns.certs." + k + ".topic") + rateLimit := a.ViperConfig.GetInt("apns.rateLimit.rpm") l.WithFields(logrus.Fields{ "app": k, @@ -119,6 +121,7 @@ func (a *APNSPusher) configure(queue interfaces.APNSPushQueue, db interfaces.DB, a.feedbackReporters, queue, interfaces.ConsumptionManager(q), + extensions.NewRateLimiter(rateLimit, a.ViperConfig, a.StatsReporters, l.Logger), ) if err == nil { a.MessageHandler[k] = handler diff --git a/pusher/gcm.go b/pusher/gcm.go index ef44c78..08303b2 100644 --- a/pusher/gcm.go +++ b/pusher/gcm.go @@ -98,6 +98,8 @@ func (g *GCMPusher) createMessageHandlerForApps(ctx context.Context) error { g.MessageHandler = make(map[string]interfaces.MessageHandler) for _, app := range g.Config.GetGcmAppsArray() { credentials := g.ViperConfig.GetString("gcm.firebaseCredentials." + app) + rateLimit := g.ViperConfig.GetInt("gcm.rateLimit.rpm") + l = l.WithField("app", app) if credentials != "" { // Firebase is configured, use new handler pushClient, err := client.NewFirebaseClient(ctx, credentials, g.Logger) @@ -111,6 +113,7 @@ func (g *GCMPusher) createMessageHandlerForApps(ctx context.Context) error { pushClient, g.feedbackReporters, g.StatsReporters, + extensions.NewRateLimiter(rateLimit, g.ViperConfig, g.StatsReporters, l.Logger), g.Logger, g.Config.GCM.ConcurrentWorkers, ) @@ -123,6 +126,7 @@ func (g *GCMPusher) createMessageHandlerForApps(ctx context.Context) error { g.Queue.PendingMessagesWaitGroup(), g.StatsReporters, g.feedbackReporters, + extensions.NewRateLimiter(rateLimit, g.ViperConfig, g.StatsReporters, l.Logger), ) if err != nil { l.WithError(err).Error("could not create gcm message handler")