Skip to content

Commit

Permalink
Reconnect RPC client on shutdown (#133)
Browse files Browse the repository at this point in the history
* feat: Reconnect RPC client on shutdown

* feat: Separate rpc message sending from RPC client handling

* fix: Try reinitializing rpc client on each tick

* feat: Only resend on a tick

* feat: Add more logs to rpc initialization

* refactor: Refactor onTick into cases

* fix: handle rpc error as goroutine

* fix: Only spawn goroutine on shutdown

* fix: Remove empty default case
  • Loading branch information
Hyodar authored Apr 30, 2024
1 parent 66193e9 commit c285cff
Showing 1 changed file with 62 additions and 26 deletions.
88 changes: 62 additions & 26 deletions operator/rpc_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package operator

import (
"errors"
"fmt"
"net/rpc"
"sync"
Expand Down Expand Up @@ -29,7 +30,7 @@ const (
type RpcMessage = interface{}

type AggregatorRpcClient struct {
rpcClientLock sync.Mutex
rpcClientLock sync.RWMutex
rpcClient *rpc.Client
aggregatorIpPortAddr string

Expand Down Expand Up @@ -71,45 +72,77 @@ func (c *AggregatorRpcClient) WithMetrics(registry *prometheus.Registry) error {
}

func (c *AggregatorRpcClient) dialAggregatorRpcClient() error {
c.rpcClientLock.Lock()
defer c.rpcClientLock.Unlock()

if c.rpcClient != nil {
return nil
}

c.logger.Info("rpc client is nil. Dialing aggregator rpc client")

client, err := rpc.DialHTTP("tcp", c.aggregatorIpPortAddr)
if err != nil {
c.logger.Error("Error dialing aggregator rpc client", "err", err)
return err
}

c.rpcClient = client

return nil
}

func (c *AggregatorRpcClient) InitializeClientIfNotExist() error {
c.rpcClientLock.Lock()
defer c.rpcClientLock.Unlock()

c.rpcClientLock.RLock()
if c.rpcClient != nil {
c.rpcClientLock.RUnlock()
return nil
}
c.rpcClientLock.RUnlock()

return c.dialAggregatorRpcClient()
}

func (c *AggregatorRpcClient) handleRpcError(err error) error {
if err == rpc.ErrShutdown {
go c.handleRpcShutdown()
}

return nil
}

func (c *AggregatorRpcClient) handleRpcShutdown() {
c.rpcClientLock.Lock()
defer c.rpcClientLock.Unlock()

if c.rpcClient != nil {
c.logger.Info("Closing RPC client due to shutdown")

err := c.rpcClient.Close()
if err != nil {
c.logger.Error("Error closing RPC client", "err", err)
}

c.rpcClient = nil
}
}

func (c *AggregatorRpcClient) onTick() {
tickerC := c.resendTicker.C
for {
// TODO(edwin): handle closed chan
<-tickerC

{
c.unsentMessagesLock.Lock()
if len(c.unsentMessages) == 0 {
c.unsentMessagesLock.Unlock()
continue
}
c.unsentMessagesLock.Unlock()
}
<-c.resendTicker.C

err := c.InitializeClientIfNotExist()
if err != nil {
c.logger.Error("Error initializing client", "err", err)
continue
}

c.unsentMessagesLock.Lock()
if len(c.unsentMessages) == 0 {
c.unsentMessagesLock.Unlock()
continue
}
c.unsentMessagesLock.Unlock()

c.tryResendFromDeque()
}
Expand Down Expand Up @@ -164,39 +197,42 @@ func (c *AggregatorRpcClient) tryResendFromDeque() {
}

func (c *AggregatorRpcClient) sendOperatorMessage(sendCb func() error, message RpcMessage) {
c.rpcClientLock.RLock()
defer c.rpcClientLock.RUnlock()

appendProtected := func() {
c.unsentMessagesLock.Lock()
c.unsentMessages = append(c.unsentMessages, message)
c.unsentMessagesLock.Unlock()
}

err := c.InitializeClientIfNotExist()
if err != nil {
if c.rpcClient == nil {
appendProtected()
return
}

c.logger.Info("Sending request to aggregator")
err = sendCb()
err := sendCb()
if err != nil {
c.handleRpcError(err)
appendProtected()
return
}

c.tryResendFromDeque()
}

func (c *AggregatorRpcClient) sendRequest(sendCb func() error) error {
err := c.InitializeClientIfNotExist()
if err != nil {
c.logger.Error("Could not reinitialize RPC client")
return err
c.rpcClientLock.RLock()
defer c.rpcClientLock.RUnlock()

if c.rpcClient == nil {
return errors.New("rpc client is nil")
}

c.logger.Info("Sending request to aggregator")

err = sendCb()
err := sendCb()
if err != nil {
c.handleRpcError(err)
return err
}

Expand Down

0 comments on commit c285cff

Please sign in to comment.