Skip to content

Commit

Permalink
fix unit tests, make tokenManager interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed Sep 19, 2024
1 parent befc3a5 commit 7cb8408
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 45 deletions.
8 changes: 7 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@ func NewClient(
logger = logger.WithGroup("avs")
logger.Info("creating new client")

var grpcToken tokenManager

if credentials != nil {
grpcToken = newGrpcJWTToken(credentials.username, credentials.password, logger)
}

connectionProvider, err := newConnectionProvider(
ctx,
seeds,
listenerName,
isLoadBalancer,
credentials,
grpcToken,
tlsConfig,
logger,
)
Expand Down
6 changes: 3 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3209,7 +3209,7 @@ func TestConnectedNodeEndpoint_Success(t *testing.T) {
defer ctrl.Finish()

mockConnProvider := NewMockconnProvider(ctrl)
mockGrpcConn := NewMockGrpcClientConn(ctrl)
mockGrpcConn := NewMockgrpcClientConn(ctrl)
mockConn := &connection{
grpcConn: mockGrpcConn,
}
Expand Down Expand Up @@ -3247,7 +3247,7 @@ func TestConnectedNodeEndpoint_FailedGetConn(t *testing.T) {
defer ctrl.Finish()

mockConnProvider := NewMockconnProvider(ctrl)
mockGrpcConn := NewMockGrpcClientConn(ctrl)
mockGrpcConn := NewMockgrpcClientConn(ctrl)
mockConn := &connection{
grpcConn: mockGrpcConn,
}
Expand Down Expand Up @@ -3276,7 +3276,7 @@ func TestConnectedNodeEndpoint_FailParsePort(t *testing.T) {
defer ctrl.Finish()

mockConnProvider := NewMockconnProvider(ctrl)
mockGrpcConn := NewMockGrpcClientConn(ctrl)
mockGrpcConn := NewMockgrpcClientConn(ctrl)
mockConn := &connection{
grpcConn: mockGrpcConn,
}
Expand Down
38 changes: 21 additions & 17 deletions connection_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,27 @@ import (

var errConnectionProviderClosed = errors.New("connection provider is closed")

type GrpcClientConn interface {
type grpcClientConn interface {
grpc.ClientConnInterface
Target() string
Close() error
}

type tokenManager interface {
RequireTransportSecurity() bool
ScheduleRefresh(func() (*connection, error))
RefreshToken(context.Context, grpcClientConn) error
UnaryInterceptor() grpc.UnaryClientInterceptor
StreamInterceptor() grpc.StreamClientInterceptor
Close()
}

// connection represents a gRPC client connection and all the clients (stubs)
// for the various AVS services. It's main purpose to remove the need to create
// multiple clients for the same connection. This follows the documented grpc
// best practice of reusing connections.
type connection struct {
grpcConn GrpcClientConn
grpcConn grpcClientConn
transactClient protos.TransactServiceClient
authClient protos.AuthServiceClient
userAdminClient protos.UserAdminServiceClient
Expand All @@ -44,7 +53,7 @@ type connection struct {
}

// newConnection creates a new connection instance.
func newConnection(conn GrpcClientConn) *connection {
func newConnection(conn grpcClientConn) *connection {
return &connection{
grpcConn: conn,
transactClient: protos.NewTransactServiceClient(conn),
Expand Down Expand Up @@ -85,7 +94,7 @@ type connectionProvider struct {
clusterID uint64
listenerName *string
isLoadBalancer bool
token *tokenManager
token tokenManager
stopTendChan chan struct{}
closed atomic.Bool
}
Expand All @@ -96,7 +105,7 @@ func newConnectionProvider(
seeds HostPortSlice,
listenerName *string,
isLoadBalancer bool,
credentials *UserPassCredentials,
token tokenManager,
tlsConfig *tls.Config,
logger *slog.Logger,
) (*connectionProvider, error) {
Expand All @@ -115,12 +124,7 @@ func newConnectionProvider(
return nil, errors.New(msg)
}

// Create a token manager if username and password are provided.
var token *tokenManager

if credentials != nil {
token = newJWTToken(credentials.username, credentials.password, logger)

if token != nil {
if token.RequireTransportSecurity() && tlsConfig == nil {
msg := "tlsConfig is required when username/password authentication"
logger.Error(msg)
Expand Down Expand Up @@ -324,7 +328,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error {
var authErr error

wg := sync.WaitGroup{}
seedGrpcConns := make(chan GrpcClientConn)
seedGrpcConns := make(chan grpcClientConn)
cp.seedConns = []*connection{}
tokenLock := sync.Mutex{} // Ensures only one thread attempts to update token at a time
tokenUpdated := false // Ensures token update only occurs once
Expand Down Expand Up @@ -444,11 +448,11 @@ func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool {
}

// getTendConns returns all the gRPC client connections for tend operations.
func (cp *connectionProvider) getTendConns() []GrpcClientConn {
func (cp *connectionProvider) getTendConns() []grpcClientConn {
cp.nodeConnsLock.RLock()
defer cp.nodeConnsLock.RUnlock()

conns := make([]GrpcClientConn, len(cp.seedConns)+len(cp.nodeConns))
conns := make([]grpcClientConn, len(cp.seedConns)+len(cp.nodeConns))
i := 0

for _, conn := range cp.seedConns {
Expand All @@ -474,7 +478,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6
for _, conn := range conns {
wg.Add(1)

go func(conn GrpcClientConn) {
go func(conn grpcClientConn) {
defer wg.Done()

logger := cp.logger.With(slog.String("host", conn.Target()))
Expand Down Expand Up @@ -685,7 +689,7 @@ func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort {
// successful endpoint in endpoints.
func (cp *connectionProvider) createGrpcConnFromEndpoints(
endpoints *protos.ServerEndpointList,
) (GrpcClientConn, error) {
) (grpcClientConn, error) {
for _, endpoint := range endpoints.Endpoints {
if strings.ContainsRune(endpoint.Address, ':') {
continue // TODO: Add logging and support for IPv6
Expand All @@ -703,7 +707,7 @@ func (cp *connectionProvider) createGrpcConnFromEndpoints(

// createGrcpConn creates a gRPC client connection to a host. This handles adding
// credential and configuring tls.
func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (GrpcClientConn, error) {
func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (grpcClientConn, error) {
opts := []grpc.DialOption{}

if cp.tlsConfig == nil {
Expand Down
10 changes: 5 additions & 5 deletions connection_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ func TestNewConnectionProvider(t *testing.T) {
seeds := HostPortSlice{}
listenerName := "listener"
isLoadBalancer := false
credentials := &UserPassCredentials{
username: "admin",
password: "password",
}
// credentials := &UserPassCredentials{
// username: "admin",
// password: "password",
// }
tlsConfig := &tls.Config{}
var logger *slog.Logger

cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, credentials, tlsConfig, logger)
cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, nil, tlsConfig, logger)

assert.Nil(t, cp)
assert.Error(t, err, errors.New("seeds cannot be nil or empty"))
Expand Down
15 changes: 9 additions & 6 deletions protos/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ func TestConvertFromValue(t *testing.T) {
},
},
},
expected: []any{float32(1), float32(2)},
expected: []float32{float32(1), float32(2)},
},
{
input: &Value{
Expand Down Expand Up @@ -1079,8 +1079,11 @@ func TestConvertToFields(t *testing.T) {

for _, tc := range testCases {
result, err := ConvertToFields(tc.input)
assert.Equal(t, len(tc.expected), len(result))

assert.Equal(t, tc.expected, result)
for i, _ := range tc.expected {
assert.EqualExportedValues(t, tc.expected[i], result[i])
}

if tc.expectedErr {
assert.Error(t, err)
Expand Down Expand Up @@ -1144,7 +1147,7 @@ func (*unknownVectorType) isVector_Data() {}
func TestConvertFromVector(t *testing.T) {
testCases := []struct {
input *Vector
expected []any
expected any
expectedErr error
}{
{
Expand All @@ -1155,7 +1158,7 @@ func TestConvertFromVector(t *testing.T) {
},
},
},
expected: []any{float32(1), float32(2)},
expected: []float32{float32(1), float32(2)},
},
{
input: &Vector{
Expand All @@ -1165,12 +1168,12 @@ func TestConvertFromVector(t *testing.T) {
},
},
},
expected: []any{true, false},
expected: []bool{true, false},
},
{
input: &Vector{Data: &unknownVectorType{}},
expected: nil,
expectedErr: fmt.Errorf("unsupported value type: *protos.unknownVectorType"),
expectedErr: fmt.Errorf("unsupported vector data type: *protos.unknownVectorType"),
},
}

Expand Down
26 changes: 13 additions & 13 deletions token_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import (
"google.golang.org/grpc/metadata"
)

// tokenManager is responsible for managing authentication tokens and refreshing
// grpcTokenManager is responsible for managing authentication tokens and refreshing
// them when necessary.
//
//nolint:govet // We will favor readability over field alignment
type tokenManager struct {
type grpcTokenManager struct {
username string
password string
token atomic.Value
Expand All @@ -29,13 +29,13 @@ type tokenManager struct {
refreshScheduled bool
}

// newJWTToken creates a new tokenManager instance with the provided username, password, and logger.
func newJWTToken(username, password string, logger *slog.Logger) *tokenManager {
// newGrpcJWTToken creates a new tokenManager instance with the provided username, password, and logger.
func newGrpcJWTToken(username, password string, logger *slog.Logger) *grpcTokenManager {
logger.WithGroup("jwt")

logger.Debug("creating new token manager")

return &tokenManager{
return &grpcTokenManager{
username: username,
password: password,
logger: logger,
Expand All @@ -44,7 +44,7 @@ func newJWTToken(username, password string, logger *slog.Logger) *tokenManager {
}

// Close stops the scheduled token refresh and closes the token manager.
func (tm *tokenManager) Close() {
func (tm *grpcTokenManager) Close() {
if tm.refreshScheduled {
tm.logger.Debug("stopping scheduled token refresh")
tm.stopRefreshChan <- struct{}{}
Expand All @@ -55,14 +55,14 @@ func (tm *tokenManager) Close() {
}

// setRefreshTimeFromTTL sets the refresh time based on the provided time-to-live (TTL) duration.
func (tm *tokenManager) setRefreshTimeFromTTL(ttl time.Duration) {
func (tm *grpcTokenManager) setRefreshTimeFromTTL(ttl time.Duration) {
tm.refreshTime.Store(time.Now().Add(ttl))
}

// RefreshToken refreshes the authentication token using the provided gRPC client connection.
// It returns a boolean indicating if the token was successfully refreshed and
// an error if any. It is not thread safe.
func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnInterface) error {
func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn grpcClientConn) error {
// We only want one goroutine to refresh the token at a time
client := protos.NewAuthServiceClient(conn)
resp, err := client.Authenticate(ctx, &protos.AuthRequest{
Expand Down Expand Up @@ -120,7 +120,7 @@ func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnIn
// ScheduleRefresh schedules the token refresh using the provided function to
// get the gRPC client connection. This is not threadsafe. It should only be
// called once.
func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) {
func (tm *grpcTokenManager) ScheduleRefresh(getConn func() (*connection, error)) {
if tm.refreshScheduled {
tm.logger.Warn("refresh already scheduled")
}
Expand Down Expand Up @@ -167,12 +167,12 @@ func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) {
}

// RequireTransportSecurity returns true to indicate that transport security is required.
func (tm *tokenManager) RequireTransportSecurity() bool {
func (tm *grpcTokenManager) RequireTransportSecurity() bool {
return true
}

// UnaryInterceptor returns the grpc unary client interceptor that attaches the token to outgoing requests.
func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor {
func (tm *grpcTokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
Expand All @@ -186,7 +186,7 @@ func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor {
}

// StreamInterceptor returns the grpc stream client interceptor that attaches the token to outgoing requests.
func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor {
func (tm *grpcTokenManager) StreamInterceptor() grpc.StreamClientInterceptor {
return func(
ctx context.Context,
desc *grpc.StreamDesc,
Expand All @@ -200,7 +200,7 @@ func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor {
}

// attachToken attaches the authentication token to the outgoing context.
func (tm *tokenManager) attachToken(ctx context.Context) context.Context {
func (tm *grpcTokenManager) attachToken(ctx context.Context) context.Context {
rawToken := tm.token.Load()
if rawToken == nil {
return ctx
Expand Down

0 comments on commit 7cb8408

Please sign in to comment.