diff --git a/connection_provider.go b/connection_provider.go index 8d9de6a..ea6d253 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -65,7 +65,11 @@ func newConnection(conn grpcClientConn) *connection { } func (conn *connection) close() error { - return conn.grpcConn.Close() + if conn.grpcConn != nil { + return conn.grpcConn.Close() + } + + return nil } // connectionAndEndpoints represents a combination of a gRPC client connection and server endpoints. @@ -354,7 +358,6 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { grpcConn, err := cp.grpcConnFactory(seed) if err != nil { logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) - grpcConn.Close() return } diff --git a/connection_provider_test.go b/connection_provider_test.go index 647e5da..a4265d5 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -170,7 +170,85 @@ func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) { assert.Equal(t, errors.New("no seed connections found"), err) } -func TestconnectToSeeds(t *testing.T) { +func TestConnectToSeeds_FailedAlreadyConnected(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.seedConns = []*connection{ + { + grpcConn: NewMockgrpcClientConn(ctrl), + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, errors.New("seed connections already exist, close them first"), err) +} + +func TestConnectToSeeds_FailedFailedToCreateConnection(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + return nil, fmt.Errorf("foo") + }, + } + + cp.seeds = HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, NewAVSError("failed to connect to seeds", nil), err) +} + +func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockToken := NewMocktokenManager(ctrl) + mockToken. + EXPECT(). + RefreshToken(gomock.Any(), gomock.Any()). + Return(fmt.Errorf("foo")) + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + return nil, nil + }, + connFactory: func(conn grpcClientConn) *connection { + return &connection{} + }, + token: mockToken, + } + + cp.seeds = HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, NewAVSError("failed to connect to seeds", fmt.Errorf("foo")), err) } func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { @@ -253,7 +331,7 @@ func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { assert.Len(t, cp.nodeConns, 2) } -func TestUpdateClusterConns_NewClusterID(t *testing.T) { +func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -366,11 +444,44 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { Target(). Return("") + expectedClusterID := uint64(456) + // After a new cluster is discovered we expect these to be the new nodeConns + expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ + 3: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + 4: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + } + mockClusterInfoClient2. EXPECT(). GetClusterId(gomock.Any(), gomock.Any()). Return(&protos.ClusterId{ - Id: 456, + Id: expectedClusterID, }, nil) mockClusterInfoClient2. @@ -420,9 +531,183 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { }, } + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, expectedClusterID, cp.clusterID) + assert.Len(t, cp.nodeConns, 2) + + for k, v := range cp.nodeConns { + assert.EqualExportedValues(t, expectedNewNodeConns[k].endpoints, v.endpoints) + } +} + +func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + grpcConn1 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + expectedClusterID := uint64(456) + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: expectedClusterID, + }, nil) + + mockClusterInfoClient2. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // larger, so the cluster id 456 will win + 1: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + 2: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + }, nil) + + mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) + mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) + + mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) + mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) + + mockAboutClient1111 := protos.NewMockAboutServiceClient(ctrl) + mockAboutClient2222 := protos.NewMockAboutServiceClient(ctrl) + + mockAboutClient1111. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockAboutClient2222. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + cp := &connectionProvider{ + logger: slog.Default(), + seedConns: []*connection{}, + tlsConfig: &tls.Config{}, + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + if hostPort.String() == "1.1.1.1:3000" { + return mockNewGrpcConn1111, nil + } else if hostPort.String() == "2.2.2.2:3000" { + return mockNewGrpcConn2222, nil + } + + return nil, fmt.Errorf("foo") + }, + connFactory: func(grpcConn grpcClientConn) *connection { + if grpcConn == mockNewGrpcConn1111 { + return &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + } + } else if grpcConn == mockNewGrpcConn2222 { + return &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + } + } + + return nil + }, + // Existing node connections. These will be replaced after a new cluster is found. + nodeConns: map[uint64]*connectionAndEndpoints{ + 1: { + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + }, + 2: { + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + }, + }, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 789, // Different cluster id from client 2 + }, nil) + + mockClusterInfoClient1. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // Smaller num of endpoints from client 2 + 0: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn1. + EXPECT(). + Close(). + Return(nil) + + grpcConn2. + EXPECT(). + Target(). + Return("") + // After a new cluster is discovered we expect these to be the new nodeConns expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ - 3: { + 1: { conn: &connection{ clusterInfoClient: mockClusterInfoClient1111, aboutClient: mockAboutClient1111, @@ -436,7 +721,7 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { }, }, }, - 4: { + 2: { conn: &connection{ clusterInfoClient: mockClusterInfoClient2222, aboutClient: mockAboutClient2222, @@ -452,11 +737,16 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { }, } + grpcConn2. + EXPECT(). + Close(). + Return(nil) + cp.logger.Debug("Running updateClusterConns") cp.updateClusterConns(ctx) - assert.Equal(t, uint64(456), cp.clusterID) + assert.Equal(t, expectedClusterID, cp.clusterID) assert.Len(t, cp.nodeConns, 2) for k, v := range cp.nodeConns {