From 44b6a41d18d7246140e58f74722c3a0437bc90a7 Mon Sep 17 00:00:00 2001 From: Aditya Thebe Date: Wed, 11 Sep 2024 12:58:08 +0545 Subject: [PATCH 1/2] feat: allow custom route extractor on pg notify router --- postq/pg/router.go | 51 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/postq/pg/router.go b/postq/pg/router.go index 29188ee9..0b88bb36 100644 --- a/postq/pg/router.go +++ b/postq/pg/router.go @@ -6,21 +6,49 @@ import ( "github.com/flanksource/duty/context" ) +type routeExtractorFn func(string) (string, string, error) + +func defaultRouteExtractor(payload string) (string, string, error) { + // The original payload is expected to be in the form of + // <...optional payload> + fields := strings.Fields(payload) + route := fields[0] + derivedPayload := strings.Join(fields[1:], " ") + return route, derivedPayload, nil +} + // notifyRouter distributes the pgNotify event to multiple channels // based on the payload. type notifyRouter struct { - registry map[string]chan string + registry map[string]chan string + routeExtractor routeExtractorFn } func NewNotifyRouter() *notifyRouter { return ¬ifyRouter{ - registry: make(map[string]chan string), + registry: make(map[string]chan string), + routeExtractor: defaultRouteExtractor, } } +func (t *notifyRouter) WithRouteExtractor(routeExtractor routeExtractorFn) *notifyRouter { + t.routeExtractor = routeExtractor + return t +} + // RegisterRoutes creates a single channel for the given routes and returns it. func (t *notifyRouter) RegisterRoutes(routes ...string) <-chan string { + // If any of the routes already has a channel, we use that + // for all the routes. + // Caution: The caller needs to ensure that the route + // groups do not overlap. pgNotifyChannel := make(chan string) + for _, we := range routes { + if existing, ok := t.registry[we]; ok { + pgNotifyChannel = existing + } + } + for _, we := range routes { t.registry[we] = pgNotifyChannel } @@ -33,18 +61,23 @@ func (t *notifyRouter) Run(ctx context.Context, channel string) { go Listen(ctx, channel, eventQueueNotifyChannel) for payload := range eventQueueNotifyChannel { - if _, ok := t.registry[payload]; !ok || payload == "" { + if payload == "" { continue } - // The original payload is expected to be in the form of - // <...optional payload> - fields := strings.Fields(payload) - route := fields[0] - derivedPayload := strings.Join(fields[1:], " ") + route, extractedPayload, err := t.routeExtractor(payload) + if err != nil { + continue + } + + if _, ok := t.registry[route]; !ok { + continue + } if ch, ok := t.registry[route]; ok { - ch <- derivedPayload + go func() { + ch <- extractedPayload + }() } } } From dee89e6b07092533b02ab5a84466ac883a4ab5ea Mon Sep 17 00:00:00 2001 From: Aditya Thebe Date: Wed, 11 Sep 2024 20:46:42 +0545 Subject: [PATCH 2/2] feat: test for pg router --- postq/pg/router.go | 6 +++- postq/pg/router_test.go | 63 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 postq/pg/router_test.go diff --git a/postq/pg/router.go b/postq/pg/router.go index 0b88bb36..3072e096 100644 --- a/postq/pg/router.go +++ b/postq/pg/router.go @@ -60,7 +60,11 @@ func (t *notifyRouter) Run(ctx context.Context, channel string) { eventQueueNotifyChannel := make(chan string) go Listen(ctx, channel, eventQueueNotifyChannel) - for payload := range eventQueueNotifyChannel { + t.start(eventQueueNotifyChannel) +} + +func (t *notifyRouter) start(channel chan string) { + for payload := range channel { if payload == "" { continue } diff --git a/postq/pg/router_test.go b/postq/pg/router_test.go new file mode 100644 index 00000000..f480d62c --- /dev/null +++ b/postq/pg/router_test.go @@ -0,0 +1,63 @@ +package pg + +import ( + "sync" + "testing" + "time" +) + +func TestPGRouter(t *testing.T) { + // Create & run the router + r := NewNotifyRouter() + pgNotifyChan := make(chan string) + go func() { + r.start(pgNotifyChan) + }() + + // Two subscribers + alpha := r.RegisterRoutes("alphaA", "alphaB") + beta := r.RegisterRoutes("beta") + + var alphaCount, betaCount int + timeout := time.NewTimer(time.Second * 3) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-alpha: + alphaCount++ + if alphaCount+betaCount == 3 { + return + } + + case <-beta: + betaCount++ + if alphaCount+betaCount == 3 { + return + } + + case <-timeout.C: + return + } + } + }() + + // Simulate receiving pg notify + go func() { + pgNotifyChan <- "alphaA 1" + pgNotifyChan <- "beta 1" + pgNotifyChan <- "alphaB 1" + }() + + wg.Wait() + if alphaCount != 2 { + t.Errorf("Expected alphaCount to be 2, got %d", alphaCount) + } + + if betaCount != 1 { + t.Errorf("Expected betaCount to be 1, got %d", betaCount) + } +}