From 7eb745870053d983d9cbe1a6051b1ac098ffbdc3 Mon Sep 17 00:00:00 2001 From: Karol Kokoszka Date: Wed, 14 Feb 2024 17:37:17 +0100 Subject: [PATCH] fix(healtcheck): respect force-tls-disabled and force-non-ssl-session-port cluster properties --- pkg/cmd/scylla-manager/server.go | 1 + pkg/service/cluster/service.go | 3 + pkg/service/healthcheck/service.go | 63 +++++++++++++------ .../healthcheck/service_integration_test.go | 41 ++++++++---- 4 files changed, 77 insertions(+), 31 deletions(-) diff --git a/pkg/cmd/scylla-manager/server.go b/pkg/cmd/scylla-manager/server.go index 447ce28714..ee2e2001de 100644 --- a/pkg/cmd/scylla-manager/server.go +++ b/pkg/cmd/scylla-manager/server.go @@ -81,6 +81,7 @@ func (s *server) makeServices() error { s.config.Healthcheck, s.clusterSvc.Client, secretsStore, + s.clusterSvc.GetClusterByID, s.logger.Named("healthcheck"), ) if err != nil { diff --git a/pkg/service/cluster/service.go b/pkg/service/cluster/service.go index 3e5aed0b0c..8182ad501b 100644 --- a/pkg/service/cluster/service.go +++ b/pkg/service/cluster/service.go @@ -25,6 +25,9 @@ import ( "go.uber.org/multierr" ) +// ProviderFunc defines the function that will be used by other services to get current cluster data. +type ProviderFunc func(ctx context.Context, id uuid.UUID) (*Cluster, error) + // ChangeType specifies type on Change. type ChangeType int8 diff --git a/pkg/service/healthcheck/service.go b/pkg/service/healthcheck/service.go index 89f0a5d943..4427704832 100644 --- a/pkg/service/healthcheck/service.go +++ b/pkg/service/healthcheck/service.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/scylladb/go-log" + "github.com/scylladb/scylla-manager/v3/pkg/service/cluster" "golang.org/x/sync/errgroup" "github.com/scylladb/scylla-manager/v3/pkg/ping" @@ -51,17 +52,23 @@ func (pt pingType) String() string { return "unknown" } +type tlsConfigWithAddress struct { + *tls.Config + Address string +} + type nodeInfo struct { *scyllaclient.NodeInfo - TLSConfig map[pingType]*tls.Config + TLSConfig map[pingType]*tlsConfigWithAddress Expires time.Time } // Service manages health checks. type Service struct { - config Config - scyllaClient scyllaclient.ProviderFunc - secretsStore store.Store + config Config + scyllaClient scyllaclient.ProviderFunc + secretsStore store.Store + clusterProvider cluster.ProviderFunc cacheMu sync.Mutex // fields below are protected by cacheMu @@ -70,17 +77,20 @@ type Service struct { logger log.Logger } -func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store, logger log.Logger) (*Service, error) { +func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store, + clusterProvider cluster.ProviderFunc, logger log.Logger, +) (*Service, error) { if scyllaClient == nil { return nil, errors.New("invalid scylla provider") } return &Service{ - config: config, - scyllaClient: scyllaClient, - secretsStore: secretsStore, - nodeInfoCache: make(map[clusterIDHost]nodeInfo), - logger: logger, + config: config, + scyllaClient: scyllaClient, + secretsStore: secretsStore, + clusterProvider: clusterProvider, + nodeInfoCache: make(map[clusterIDHost]nodeInfo), + logger: logger, }, nil } @@ -354,7 +364,7 @@ func (s *Service) pingCQL(ctx context.Context, clusterID uuid.UUID, host string, tlsConfig := ni.tlsConfig(cqlPing) if tlsConfig != nil { - config.Addr = ni.CQLSSLAddr(host) + config.Addr = tlsConfig.Address config.TLSConfig = tlsConfig.Clone() } @@ -399,6 +409,12 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string if ni, ok := s.nodeInfoCache[key]; ok && now.Before(ni.Expires) { return ni, nil } + + c, err := s.clusterProvider(ctx, clusterID) + if err != nil { + return nodeInfo{}, err + } + client, err := s.scyllaClient(ctx, clusterID) if err != nil { return nodeInfo{}, errors.Wrap(err, "create scylla client") @@ -409,13 +425,20 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string return nodeInfo{}, errors.Wrap(err, "fetch node info") } - ni.TLSConfig = make(map[pingType]*tls.Config, 2) + ni.TLSConfig = make(map[pingType]*tlsConfigWithAddress, 2) for _, p := range []pingType{alternatorPing, cqlPing} { var tlsEnabled, clientCertAuth bool + var address string if p == cqlPing { + address = ni.CQLAddr(host) tlsEnabled, clientCertAuth = ni.CQLTLSEnabled() + tlsEnabled = tlsEnabled && !c.ForceTLSDisabled + if tlsEnabled && !c.ForceNonSSLSessionPort { + address = ni.CQLSSLAddr(host) + } } else if p == alternatorPing { tlsEnabled, clientCertAuth = ni.AlternatorTLSEnabled() + address = ni.AlternatorAddr(host) } if tlsEnabled { tlsConfig, err := s.tlsConfig(clusterID, clientCertAuth) @@ -423,11 +446,11 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string return ni, errors.Wrap(err, "fetch TLS config") } if clientCertAuth && errors.Is(err, service.ErrNotFound) { - s.logger.Info(ctx, "Client encryption is enabled, but Cluster wasn't registered with certificate in Scylla Manager, falling back to nonSSL port.", - "cluster_id", clusterID, - ) - } else { - ni.TLSConfig[p] = tlsConfig + return nodeInfo{}, errors.Wrap(err, "client encryption is enabled, but certificate is missing") + } + ni.TLSConfig[p] = &tlsConfigWithAddress{ + Config: tlsConfig, + Address: address, } } } @@ -439,7 +462,7 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string } func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Config, error) { - cfg := &tls.Config{ + cfg := tls.Config{ InsecureSkipVerify: true, } @@ -457,7 +480,7 @@ func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Conf cfg.Certificates = []tls.Certificate{keyPair} } - return cfg, nil + return &cfg, nil } func (s *Service) cqlCreds(ctx context.Context, clusterID uuid.UUID) *secrets.CQLCreds { @@ -486,7 +509,7 @@ func (s *Service) InvalidateCache(clusterID uuid.UUID) { s.cacheMu.Unlock() } -func (ni nodeInfo) tlsConfig(pt pingType) *tls.Config { +func (ni nodeInfo) tlsConfig(pt pingType) *tlsConfigWithAddress { return ni.TLSConfig[pt] } diff --git a/pkg/service/healthcheck/service_integration_test.go b/pkg/service/healthcheck/service_integration_test.go index 3a8bc650ef..bc6f200b02 100644 --- a/pkg/service/healthcheck/service_integration_test.go +++ b/pkg/service/healthcheck/service_integration_test.go @@ -18,11 +18,12 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/scylladb/go-log" + "github.com/scylladb/scylla-manager/v3/pkg/metrics" + "github.com/scylladb/scylla-manager/v3/pkg/service/cluster" "go.uber.org/zap/zapcore" "github.com/scylladb/scylla-manager/v3/pkg/schema/table" "github.com/scylladb/scylla-manager/v3/pkg/scyllaclient" - "github.com/scylladb/scylla-manager/v3/pkg/secrets" "github.com/scylladb/scylla-manager/v3/pkg/store" . "github.com/scylladb/scylla-manager/v3/pkg/testutils" . "github.com/scylladb/scylla-manager/v3/pkg/testutils/db" @@ -40,10 +41,22 @@ func TestStatusIntegration(t *testing.T) { session := CreateScyllaManagerDBSession(t) defer session.Close() - clusterID := uuid.MustRandom() - s := store.NewTableStore(session, table.Secrets) - testStatusIntegration(t, clusterID, s) + clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment()) + if err != nil { + t.Fatal(err) + } + + c := &cluster.Cluster{ + Host: "192.168.200.11", + AuthToken: "token", + } + err = clusterSvc.PutCluster(context.Background(), c) + if err != nil { + t.Fatal(err) + } + + testStatusIntegration(t, c.ID, clusterSvc.GetClusterByID, s) } func TestStatusWithCQLCredentialsIntegration(t *testing.T) { @@ -55,21 +68,26 @@ func TestStatusWithCQLCredentialsIntegration(t *testing.T) { session := CreateScyllaManagerDBSession(t) defer session.Close() - clusterID := uuid.MustRandom() - s := store.NewTableStore(session, table.Secrets) - if err := s.Put(&secrets.CQLCreds{ - ClusterID: clusterID, + clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment()) + if err != nil { + t.Fatal(err) + } + c := &cluster.Cluster{ + Host: "192.168.200.11", + AuthToken: "token", Username: username, Password: password, - }); err != nil { + } + err = clusterSvc.PutCluster(context.Background(), c) + if err != nil { t.Fatal(err) } - testStatusIntegration(t, clusterID, s) + testStatusIntegration(t, c.ID, clusterSvc.GetClusterByID, s) } -func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store.Store) { +func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterProvider cluster.ProviderFunc, secretsStore store.Store) { logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("healthcheck") // Tests here do not test the dynamic t/o functionality @@ -97,6 +115,7 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store return scyllaclient.NewClient(sc, logger.Named("scylla")) }, secretsStore, + clusterProvider, logger, ) if err != nil {