diff --git a/client.go b/client.go index e43b49d..ce15e36 100644 --- a/client.go +++ b/client.go @@ -72,12 +72,18 @@ func NewClient( logger = logger.WithGroup("avs") logger.Info("creating new client") + var grpcToken tokenManager + + if credentials != nil { + grpcToken = newGrpcJWTToken(credentials.username, credentials.password, logger) + } + connectionProvider, err := newConnectionProvider( ctx, seeds, listenerName, isLoadBalancer, - credentials, + grpcToken, tlsConfig, logger, ) diff --git a/client_test.go b/client_test.go index e9efbd2..99610b7 100644 --- a/client_test.go +++ b/client_test.go @@ -3209,7 +3209,7 @@ func TestConnectedNodeEndpoint_Success(t *testing.T) { defer ctrl.Finish() mockConnProvider := NewMockconnProvider(ctrl) - mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) mockConn := &connection{ grpcConn: mockGrpcConn, } @@ -3247,7 +3247,7 @@ func TestConnectedNodeEndpoint_FailedGetConn(t *testing.T) { defer ctrl.Finish() mockConnProvider := NewMockconnProvider(ctrl) - mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) mockConn := &connection{ grpcConn: mockGrpcConn, } @@ -3276,7 +3276,7 @@ func TestConnectedNodeEndpoint_FailParsePort(t *testing.T) { defer ctrl.Finish() mockConnProvider := NewMockconnProvider(ctrl) - mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) mockConn := &connection{ grpcConn: mockGrpcConn, } diff --git a/connection_provider.go b/connection_provider.go index eb66a1a..90c8b12 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -23,18 +23,27 @@ import ( var errConnectionProviderClosed = errors.New("connection provider is closed") -type GrpcClientConn interface { +type grpcClientConn interface { grpc.ClientConnInterface Target() string Close() error } +type tokenManager interface { + RequireTransportSecurity() bool + ScheduleRefresh(func() (*connection, error)) + RefreshToken(context.Context, grpcClientConn) error + UnaryInterceptor() grpc.UnaryClientInterceptor + StreamInterceptor() grpc.StreamClientInterceptor + Close() +} + // connection represents a gRPC client connection and all the clients (stubs) // for the various AVS services. It's main purpose to remove the need to create // multiple clients for the same connection. This follows the documented grpc // best practice of reusing connections. type connection struct { - grpcConn GrpcClientConn + grpcConn grpcClientConn transactClient protos.TransactServiceClient authClient protos.AuthServiceClient userAdminClient protos.UserAdminServiceClient @@ -44,7 +53,7 @@ type connection struct { } // newConnection creates a new connection instance. -func newConnection(conn GrpcClientConn) *connection { +func newConnection(conn grpcClientConn) *connection { return &connection{ grpcConn: conn, transactClient: protos.NewTransactServiceClient(conn), @@ -85,7 +94,7 @@ type connectionProvider struct { clusterID uint64 listenerName *string isLoadBalancer bool - token *tokenManager + token tokenManager stopTendChan chan struct{} closed atomic.Bool } @@ -96,7 +105,7 @@ func newConnectionProvider( seeds HostPortSlice, listenerName *string, isLoadBalancer bool, - credentials *UserPassCredentials, + token tokenManager, tlsConfig *tls.Config, logger *slog.Logger, ) (*connectionProvider, error) { @@ -115,12 +124,7 @@ func newConnectionProvider( return nil, errors.New(msg) } - // Create a token manager if username and password are provided. - var token *tokenManager - - if credentials != nil { - token = newJWTToken(credentials.username, credentials.password, logger) - + if token != nil { if token.RequireTransportSecurity() && tlsConfig == nil { msg := "tlsConfig is required when username/password authentication" logger.Error(msg) @@ -324,7 +328,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { var authErr error wg := sync.WaitGroup{} - seedGrpcConns := make(chan GrpcClientConn) + seedGrpcConns := make(chan grpcClientConn) cp.seedConns = []*connection{} tokenLock := sync.Mutex{} // Ensures only one thread attempts to update token at a time tokenUpdated := false // Ensures token update only occurs once @@ -444,11 +448,11 @@ func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool { } // getTendConns returns all the gRPC client connections for tend operations. -func (cp *connectionProvider) getTendConns() []GrpcClientConn { +func (cp *connectionProvider) getTendConns() []grpcClientConn { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - conns := make([]GrpcClientConn, len(cp.seedConns)+len(cp.nodeConns)) + conns := make([]grpcClientConn, len(cp.seedConns)+len(cp.nodeConns)) i := 0 for _, conn := range cp.seedConns { @@ -474,7 +478,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 for _, conn := range conns { wg.Add(1) - go func(conn GrpcClientConn) { + go func(conn grpcClientConn) { defer wg.Done() logger := cp.logger.With(slog.String("host", conn.Target())) @@ -685,7 +689,7 @@ func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { // successful endpoint in endpoints. func (cp *connectionProvider) createGrpcConnFromEndpoints( endpoints *protos.ServerEndpointList, -) (GrpcClientConn, error) { +) (grpcClientConn, error) { for _, endpoint := range endpoints.Endpoints { if strings.ContainsRune(endpoint.Address, ':') { continue // TODO: Add logging and support for IPv6 @@ -703,7 +707,7 @@ func (cp *connectionProvider) createGrpcConnFromEndpoints( // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. -func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (GrpcClientConn, error) { +func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (grpcClientConn, error) { opts := []grpc.DialOption{} if cp.tlsConfig == nil { diff --git a/connection_provider_test.go b/connection_provider_test.go index fb5474d..a0b7d09 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -14,14 +14,14 @@ func TestNewConnectionProvider(t *testing.T) { seeds := HostPortSlice{} listenerName := "listener" isLoadBalancer := false - credentials := &UserPassCredentials{ - username: "admin", - password: "password", - } + // credentials := &UserPassCredentials{ + // username: "admin", + // password: "password", + // } tlsConfig := &tls.Config{} var logger *slog.Logger - cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, credentials, tlsConfig, logger) + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, nil, tlsConfig, logger) assert.Nil(t, cp) assert.Error(t, err, errors.New("seeds cannot be nil or empty")) diff --git a/protos/utils_test.go b/protos/utils_test.go index bdabc65..6fec48b 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -1024,7 +1024,7 @@ func TestConvertFromValue(t *testing.T) { }, }, }, - expected: []any{float32(1), float32(2)}, + expected: []float32{float32(1), float32(2)}, }, { input: &Value{ @@ -1079,8 +1079,11 @@ func TestConvertToFields(t *testing.T) { for _, tc := range testCases { result, err := ConvertToFields(tc.input) + assert.Equal(t, len(tc.expected), len(result)) - assert.Equal(t, tc.expected, result) + for i, _ := range tc.expected { + assert.EqualExportedValues(t, tc.expected[i], result[i]) + } if tc.expectedErr { assert.Error(t, err) @@ -1144,7 +1147,7 @@ func (*unknownVectorType) isVector_Data() {} func TestConvertFromVector(t *testing.T) { testCases := []struct { input *Vector - expected []any + expected any expectedErr error }{ { @@ -1155,7 +1158,7 @@ func TestConvertFromVector(t *testing.T) { }, }, }, - expected: []any{float32(1), float32(2)}, + expected: []float32{float32(1), float32(2)}, }, { input: &Vector{ @@ -1165,12 +1168,12 @@ func TestConvertFromVector(t *testing.T) { }, }, }, - expected: []any{true, false}, + expected: []bool{true, false}, }, { input: &Vector{Data: &unknownVectorType{}}, expected: nil, - expectedErr: fmt.Errorf("unsupported value type: *protos.unknownVectorType"), + expectedErr: fmt.Errorf("unsupported vector data type: *protos.unknownVectorType"), }, } diff --git a/token_manager.go b/token_manager.go index 1947bcc..db1c653 100644 --- a/token_manager.go +++ b/token_manager.go @@ -15,11 +15,11 @@ import ( "google.golang.org/grpc/metadata" ) -// tokenManager is responsible for managing authentication tokens and refreshing +// grpcTokenManager is responsible for managing authentication tokens and refreshing // them when necessary. // //nolint:govet // We will favor readability over field alignment -type tokenManager struct { +type grpcTokenManager struct { username string password string token atomic.Value @@ -29,13 +29,13 @@ type tokenManager struct { refreshScheduled bool } -// newJWTToken creates a new tokenManager instance with the provided username, password, and logger. -func newJWTToken(username, password string, logger *slog.Logger) *tokenManager { +// newGrpcJWTToken creates a new tokenManager instance with the provided username, password, and logger. +func newGrpcJWTToken(username, password string, logger *slog.Logger) *grpcTokenManager { logger.WithGroup("jwt") logger.Debug("creating new token manager") - return &tokenManager{ + return &grpcTokenManager{ username: username, password: password, logger: logger, @@ -44,7 +44,7 @@ func newJWTToken(username, password string, logger *slog.Logger) *tokenManager { } // Close stops the scheduled token refresh and closes the token manager. -func (tm *tokenManager) Close() { +func (tm *grpcTokenManager) Close() { if tm.refreshScheduled { tm.logger.Debug("stopping scheduled token refresh") tm.stopRefreshChan <- struct{}{} @@ -55,14 +55,14 @@ func (tm *tokenManager) Close() { } // setRefreshTimeFromTTL sets the refresh time based on the provided time-to-live (TTL) duration. -func (tm *tokenManager) setRefreshTimeFromTTL(ttl time.Duration) { +func (tm *grpcTokenManager) setRefreshTimeFromTTL(ttl time.Duration) { tm.refreshTime.Store(time.Now().Add(ttl)) } // RefreshToken refreshes the authentication token using the provided gRPC client connection. // It returns a boolean indicating if the token was successfully refreshed and // an error if any. It is not thread safe. -func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnInterface) error { +func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn grpcClientConn) error { // We only want one goroutine to refresh the token at a time client := protos.NewAuthServiceClient(conn) resp, err := client.Authenticate(ctx, &protos.AuthRequest{ @@ -120,7 +120,7 @@ func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnIn // ScheduleRefresh schedules the token refresh using the provided function to // get the gRPC client connection. This is not threadsafe. It should only be // called once. -func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { +func (tm *grpcTokenManager) ScheduleRefresh(getConn func() (*connection, error)) { if tm.refreshScheduled { tm.logger.Warn("refresh already scheduled") } @@ -167,12 +167,12 @@ func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { } // RequireTransportSecurity returns true to indicate that transport security is required. -func (tm *tokenManager) RequireTransportSecurity() bool { +func (tm *grpcTokenManager) RequireTransportSecurity() bool { return true } // UnaryInterceptor returns the grpc unary client interceptor that attaches the token to outgoing requests. -func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { +func (tm *grpcTokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { return func( ctx context.Context, method string, @@ -186,7 +186,7 @@ func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { } // StreamInterceptor returns the grpc stream client interceptor that attaches the token to outgoing requests. -func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor { +func (tm *grpcTokenManager) StreamInterceptor() grpc.StreamClientInterceptor { return func( ctx context.Context, desc *grpc.StreamDesc, @@ -200,7 +200,7 @@ func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor { } // attachToken attaches the authentication token to the outgoing context. -func (tm *tokenManager) attachToken(ctx context.Context) context.Context { +func (tm *grpcTokenManager) attachToken(ctx context.Context) context.Context { rawToken := tm.token.Load() if rawToken == nil { return ctx