From a6c8b47442e225f0b4b85b33944bac37002e5897 Mon Sep 17 00:00:00 2001
From: Utku Ozdemir
Date: Thu, 6 Jun 2024 16:54:39 +0200
Subject: [PATCH] fix: pass through the `talosctl -n` args if they cannot be resolved

We were not correctly checking whether the nodes passed via `talosctl --nodes`
were resolved before replacing the target in the gRPC metadata with the
resolved ones.

We were handling a single node in the metadata correctly, but not doing the
same handling for multiple nodes. This PR fixes and simplifies that logic.

Signed-off-by: Utku Ozdemir
---
 internal/backend/grpc/router/omni_backend.go  | 14 +++-
 internal/backend/grpc/router/resolve.go       | 61 ++++++++------
 internal/backend/grpc/router/talos_backend.go | 30 +++++--
 .../backend/grpc/router/talos_backend_test.go | 84 +++++++++++++++++++
 4 files changed, 150 insertions(+), 39 deletions(-)

diff --git a/internal/backend/grpc/router/omni_backend.go b/internal/backend/grpc/router/omni_backend.go
index 7d4525fd..05cee26c 100644
--- a/internal/backend/grpc/router/omni_backend.go
+++ b/internal/backend/grpc/router/omni_backend.go
@@ -48,10 +48,16 @@ func (l *OmniBackend) GetConnection(ctx context.Context, _ string) (context.Cont
 	// Use a new header to avoid signature mismatch.
 	resolved := resolveNodes(l.nodeResolver, md)
 
-	if resolved.node.Address != "" {
-		md.Set(ResolvedNodesHeaderKey, resolved.node.Address)
-	} else if len(resolved.nodes) > 0 {
-		md.Set(ResolvedNodesHeaderKey, xslices.Map(resolved.nodes, func(info dns.Info) string {
+	var allNodes []dns.Info
+
+	if resolved.nodeOk {
+		allNodes = append(allNodes, resolved.node)
+	}
+
+	allNodes = append(allNodes, resolved.nodes...)
+
+	if len(allNodes) > 0 {
+		md.Set(ResolvedNodesHeaderKey, xslices.Map(allNodes, func(info dns.Info) string {
 			return info.Address
 		})...)
 	}
diff --git a/internal/backend/grpc/router/resolve.go b/internal/backend/grpc/router/resolve.go
index db037c08..58531271 100644
--- a/internal/backend/grpc/router/resolve.go
+++ b/internal/backend/grpc/router/resolve.go
@@ -9,12 +9,16 @@ import (
 	"strings"
 
 	"github.com/siderolabs/gen/xslices"
-	"github.com/siderolabs/go-api-signature/pkg/message"
 	"google.golang.org/grpc/metadata"
 
 	"github.com/siderolabs/omni/internal/backend/dns"
 )
 
+const (
+	nodeHeaderKey  = "node"
+	nodesHeaderKey = "nodes"
+)
+
 // NodeResolver resolves a given cluster and a node name to an IP address.
 type NodeResolver interface {
 	Resolve(cluster, node string) dns.Info
@@ -23,49 +27,54 @@ type resolvedNodeInfo struct {
 	node  dns.Info
 	nodes []dns.Info
+
+	nodeOk bool
 }
 
 func resolveNodes(dnsService NodeResolver, md metadata.MD) resolvedNodeInfo {
-	nodesVal := md.Get(message.NodesHeaderKey)
+	var (
+		node  string
+		nodes []string
 
-	cluster := getClusterName(md)
+		nodeOK bool
+	)
 
-	nodes := make([]string, 0, len(nodesVal)*2)
-	for _, node := range nodesVal {
-		nodes = append(nodes, strings.Split(node, ",")...)
-	}
+	if nodeVal := md.Get(nodeHeaderKey); len(nodeVal) > 0 {
+		nodeOK = true
 
-	node := ""
-	if nodeVal := md.Get("node"); len(nodeVal) > 0 {
 		node = nodeVal[0]
 	}
 
-	if cluster == "" {
-		return resolvedNodeInfo{
-			nodes: xslices.Map(nodes, func(n string) dns.Info {
-				return dns.Info{Address: n}
-			}),
-			node: dns.Info{Address: node},
+	if nodesVal := md.Get(nodesHeaderKey); len(nodesVal) > 0 {
+		nodes = make([]string, 0, len(nodesVal)*2)
+		for _, n := range nodesVal {
+			nodes = append(nodes, strings.Split(n, ",")...)
 		}
 	}
 
+	cluster := getClusterName(md)
+
 	resolveNode := func(val string) dns.Info {
-		if val == "" {
-			return dns.Info{}
-		}
+		var resolved dns.Info
 
-		return dnsService.Resolve(cluster, val)
-	}
+		if cluster != "" && val != "" {
+			resolved = dnsService.Resolve(cluster, val)
+		}
 
-	resolvedNode := resolveNode(node)
+		if resolved.Address == "" {
+			return dns.Info{
+				Cluster: cluster,
+				Name:    val,
+				Address: val,
+			}
+		}
 
-	resolvedNodes := make([]dns.Info, 0, len(nodes))
-	for _, n := range nodesVal {
-		resolvedNodes = append(resolvedNodes, resolveNode(n))
+		return resolved
 	}
 
 	return resolvedNodeInfo{
-		nodes: resolvedNodes,
-		node:  resolvedNode,
+		node:   resolveNode(node),
+		nodes:  xslices.Map(nodes, resolveNode),
+		nodeOk: nodeOK,
 	}
 }
 
diff --git a/internal/backend/grpc/router/talos_backend.go b/internal/backend/grpc/router/talos_backend.go
index daa129ba..a48c63f8 100644
--- a/internal/backend/grpc/router/talos_backend.go
+++ b/internal/backend/grpc/router/talos_backend.go
@@ -66,7 +66,10 @@ func (l *TalosBackend) String() string {
 
 // GetConnection returns a grpc connection to the backend.
 func (l *TalosBackend) GetConnection(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) {
-	md, _ := metadata.FromIncomingContext(ctx)
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		md = metadata.New(nil)
+	}
 
 	// we can't use regular gRPC server interceptors here, as proxy interface is a bit different
 
@@ -108,16 +111,21 @@ func (l *TalosBackend) GetConnection(ctx context.Context, fullMethodName string)
 
 	// overwrite the node headers with the resolved ones
 	resolved := resolveNodes(l.nodeResolver, md)
-	if resolved.node.Address != "" {
+
+	if resolved.nodeOk {
 		md = md.Copy()
 
-		setHeaderData(ctx, md, "node", resolved.node.Address)
-	} else if len(resolved.nodes) > 0 {
+		setHeaderData(ctx, md, nodeHeaderKey, resolved.node.Address)
+	}
+
+	if len(resolved.nodes) > 0 {
 		md = md.Copy()
 
-		setHeaderData(ctx, md, "nodes", xslices.Map(resolved.nodes, func(info dns.Info) string {
+		addresses := xslices.Map(resolved.nodes, func(info dns.Info) string {
 			return info.Address
-		})...)
+		})
+
+		setHeaderData(ctx, md, nodesHeaderKey, addresses...)
 	}
 
 	l.setRoleHeaders(ctx, md, fullMethodName, resolved, hasModifyAccess)
@@ -127,14 +135,14 @@ func (l *TalosBackend) GetConnection(ctx context.Context, fullMethodName string)
 	return outCtx, l.conn, nil
 }
 
-func (l *TalosBackend) setRoleHeaders(ctx context.Context, md metadata.MD, fullMethodName string, resolvedInfo resolvedNodeInfo, hasModifyAccess bool) {
+func (l *TalosBackend) setRoleHeaders(ctx context.Context, md metadata.MD, fullMethodName string, info resolvedNodeInfo, hasModifyAccess bool) {
 	if !hasModifyAccess {
 		setHeaderData(ctx, md, constants.APIAuthzRoleMetadataKey, talosrole.MakeSet(talosrole.Reader).Strings()...)
 
 		return
 	}
 
-	minTalosVersion := l.minTalosVersion(resolvedInfo)
+	minTalosVersion := l.minTalosVersion(info)
 
 	// min Talos version is >= 1.4.0, we can use Operator role
 	if minTalosVersion != nil && minTalosVersion.GTE(semver.MustParse("1.4.0")) {
@@ -152,7 +160,11 @@ func (l *TalosBackend) setRoleHeaders(ctx context.Context, md metadata.MD, fullM
 }
 
 func (l *TalosBackend) minTalosVersion(info resolvedNodeInfo) *semver.Version {
-	ver := takePtr(semver.ParseTolerant(info.node.TalosVersion))
+	var ver *semver.Version
+
+	if info.nodeOk {
+		ver = takePtr(semver.ParseTolerant(info.node.TalosVersion))
+	}
 
 	for _, node := range info.nodes {
 		nodeVer := takePtr(semver.ParseTolerant(node.TalosVersion))
diff --git a/internal/backend/grpc/router/talos_backend_test.go b/internal/backend/grpc/router/talos_backend_test.go
index 516e0de0..7bb9d268 100644
--- a/internal/backend/grpc/router/talos_backend_test.go
+++ b/internal/backend/grpc/router/talos_backend_test.go
@@ -65,6 +65,82 @@ func TestTalosBackendRoles(t *testing.T) {
 	require.Equal(t, "talos-machine", hostnameResult.Messages[0].Hostname)
 }
 
+func TestNodeResolution(t *testing.T) {
+	resolver := &mockResolver{
+		db: map[string]map[string]dns.Info{
+			"cluster-1": {
+				"node-1": {Address: "1.1.1.1"},
+				"node-2": {Address: "2.2.2.2"},
+			},
+		},
+	}
+
+	noOpVerifier := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+		return handler(ctx, req)
+	}
+	talosBackend := router.NewTalosBackend("test-backend", resolver, nil, false, noOpVerifier)
+
+	testCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	t.Cleanup(cancel)
+
+	t.Run(`resolvable "node"`, func(t *testing.T) {
+		ctx := metadata.NewIncomingContext(testCtx, metadata.Pairs("cluster", "cluster-1", "node", "node-1"))
+
+		newCtx, _, err := talosBackend.GetConnection(ctx, "some-method")
+		require.NoError(t, err)
+
+		incomingContext, ok := metadata.FromOutgoingContext(newCtx)
+		require.True(t, ok, "metadata not found")
+
+		require.Len(t, incomingContext["node"], 1)
+		require.Equal(t, "1.1.1.1", incomingContext["node"][0])
+	})
+
+	t.Run(`resolvable "nodes"`, func(t *testing.T) {
+		ctx := metadata.NewIncomingContext(testCtx, metadata.Pairs("cluster", "cluster-1", "nodes", "node-1,node-2"))
+
+		newCtx, _, err := talosBackend.GetConnection(ctx, "some-method")
+		require.NoError(t, err)
+
+		incomingContext, ok := metadata.FromOutgoingContext(newCtx)
+		require.True(t, ok, "metadata not found")
+
+		require.Equal(t, []string{"1.1.1.1", "2.2.2.2"}, incomingContext["nodes"])
+	})
+
+	t.Run(`both "node" and "nodes" set`, func(t *testing.T) {
+		ctx := metadata.NewIncomingContext(testCtx, metadata.Pairs("cluster", "cluster-1", "node", "node-1", "nodes", "node-1,node-2"))
+
+		newCtx, _, err := talosBackend.GetConnection(ctx, "some-method")
+		require.NoError(t, err)
+
+		incomingContext, ok := metadata.FromOutgoingContext(newCtx)
+		require.True(t, ok, "metadata not found")
+
+		require.Len(t, incomingContext["node"], 1)
+		require.Equal(t, "1.1.1.1", incomingContext["node"][0])
+		require.Equal(t, []string{"1.1.1.1", "2.2.2.2"}, incomingContext["nodes"])
+	})
+
+	t.Run(`fallback when unresolved`, func(t *testing.T) {
+		ctx := metadata.NewIncomingContext(testCtx, metadata.Pairs("cluster", "cluster-1", "node", "node-3", "nodes", "node-0,node-1,node-2,node-3"))
+
+		newCtx, _, err := talosBackend.GetConnection(ctx, "some-method")
+		require.NoError(t, err)
+
+		incomingContext, ok := metadata.FromOutgoingContext(newCtx)
+		require.True(t, ok, "metadata not found")
+
+		require.Len(t, incomingContext["node"], 1)
+
+		// "node" is unresolved, so it should be kept as-is
+		require.Equal(t, "node-3", incomingContext["node"][0])
+
+		// some of the "nodes" are unresolved; only those should be kept as-is, the rest should be resolved
+		require.Equal(t, []string{"node-0", "1.1.1.1", "2.2.2.2", "node-3"}, incomingContext["nodes"])
+	})
+}
+
 func makeGRPCProxy(ctx context.Context, endpoint, serverEndpoint string) (func() error, error) {
 	grpcProxyServer := router.NewServer(&testDirector{serverEndpoint: serverEndpoint})
 
@@ -208,3 +284,11 @@ func recvContext[T any](ctx context.Context, ch <-chan T) (T, bool) {
 		return v, true
 	}
 }
+
+type mockResolver struct {
+	db map[string]map[string]dns.Info
+}
+
+func (m *mockResolver) Resolve(cluster, node string) dns.Info {
+	return m.db[cluster][node]
+}