From 72675d21bdd8b06ccdf24e6bd0c530979b3ab238 Mon Sep 17 00:00:00 2001
From: Aditya Thebe
Date: Thu, 4 Jan 2024 09:44:18 +0545
Subject: [PATCH] feat: add tests for changes & analyses syncer

* migrated errors/http statuses from mission-control
* migrated upstream handlers from mission-control
---
 api/errors.go           |  87 +++++++++++++++++++++++++
 api/http.go             |  48 ++++++++++++++
 query/agent.go          |  53 ++++++++++++++++
 tests/upstream_test.go  | 137 ++++++++++++++++++++++++++++++++++++++++
 upstream/commands.go    | 124 ++++++++++++++++++++++++++++++++++++
 upstream/controllers.go | 128 +++++++++++++++++++++++++++++++++++++
 6 files changed, 577 insertions(+)
 create mode 100644 api/errors.go
 create mode 100644 api/http.go
 create mode 100644 query/agent.go
 create mode 100644 tests/upstream_test.go
 create mode 100644 upstream/commands.go
 create mode 100644 upstream/controllers.go

diff --git a/api/errors.go b/api/errors.go
new file mode 100644
index 00000000..c3699501
--- /dev/null
+++ b/api/errors.go
@@ -0,0 +1,87 @@
+package api
+
+import (
+	"errors"
+	"fmt"
+)
+
+// Application error codes.
+//
+// These are meant to be generic and they map well to HTTP error codes.
+const (
+	ECONFLICT       = "conflict"
+	EFORBIDDEN      = "forbidden"
+	EINTERNAL       = "internal"
+	EINVALID        = "invalid"
+	ENOTFOUND       = "not_found"
+	ENOTIMPLEMENTED = "not_implemented"
+	EUNAUTHORIZED   = "unauthorized"
+)
+
+// Error represents an application-specific error.
+type Error struct {
+	// Machine-readable error code.
+	Code string
+
+	// Human-readable error message.
+	Message string
+
+	// DebugInfo contains low-level internal error details that should only be logged.
+	// End-users should never see this.
+	DebugInfo string
+}
+
+// Error implements the error interface. Not used by the application otherwise.
+func (e *Error) Error() string {
+	return fmt.Sprintf("error: code=%s message=%s", e.Code, e.Message)
+}
+
+// WithDebugInfo wraps an application error with a debug message.
+func (e *Error) WithDebugInfo(msg string, args ...any) *Error {
+	e.DebugInfo = fmt.Sprintf(msg, args...)
+	return e
+}
+
+// ErrorCode unwraps an application error and returns its code.
+// Non-application errors always return EINTERNAL.
+func ErrorCode(err error) string {
+	var e *Error
+	if err == nil {
+		return ""
+	} else if errors.As(err, &e) {
+		return e.Code
+	}
+	return EINTERNAL
+}
+
+// ErrorMessage unwraps an application error and returns its message.
+// Non-application errors always return "Internal error".
+func ErrorMessage(err error) string {
+	var e *Error
+	if err == nil {
+		return ""
+	} else if errors.As(err, &e) {
+		return e.Message
+	}
+	return "Internal error."
+}
+
+// ErrorDebugInfo unwraps an application error and returns its debug message.
+func ErrorDebugInfo(err error) string {
+	var e *Error
+	if err == nil {
+		return ""
+	} else if errors.As(err, &e) {
+		return e.DebugInfo
+	}
+
+	return err.Error()
+}
+
+// Errorf is a helper function to return an Error with a given code and formatted message.
+func Errorf(code string, format string, args ...any) *Error {
+	return &Error{
+		Code:    code,
+		Message: fmt.Sprintf(format, args...),
+	}
+}
diff --git a/api/http.go b/api/http.go
new file mode 100644
index 00000000..e1caa925
--- /dev/null
+++ b/api/http.go
@@ -0,0 +1,48 @@
+package api
+
+import (
+	"net/http"
+
+	"github.com/flanksource/commons/logger"
+	"github.com/labstack/echo/v4"
+)
+
+type HTTPError struct {
+	Error   string `json:"error"`
+	Message string `json:"message,omitempty"`
+}
+
+type HTTPSuccess struct {
+	Message string `json:"message"`
+	Payload any    `json:"payload,omitempty"`
+}
+
+func WriteError(c echo.Context, err error) error {
+	code, message := ErrorCode(err), ErrorMessage(err)
+
+	if debugInfo := ErrorDebugInfo(err); debugInfo != "" {
+		logger.WithValues("code", code, "error", message).Errorf("%s", debugInfo) // debugInfo is data, never a format string
+	}
+
+	return c.JSON(ErrorStatusCode(code), &HTTPError{Error: message})
+}
+
+// ErrorStatusCode returns the associated HTTP status code for an application error code.
+func ErrorStatusCode(code string) int {
+	// lookup of application error codes to HTTP status codes.
+	var codes = map[string]int{
+		ECONFLICT:       http.StatusConflict,
+		EINVALID:        http.StatusBadRequest,
+		ENOTFOUND:       http.StatusNotFound,
+		EFORBIDDEN:      http.StatusForbidden,
+		ENOTIMPLEMENTED: http.StatusNotImplemented,
+		EUNAUTHORIZED:   http.StatusUnauthorized,
+		EINTERNAL:       http.StatusInternalServerError,
+	}
+
+	if v, ok := codes[code]; ok {
+		return v
+	}
+
+	return http.StatusInternalServerError
+}
diff --git a/query/agent.go b/query/agent.go
new file mode 100644
index 00000000..d20b3594
--- /dev/null
+++ b/query/agent.go
@@ -0,0 +1,53 @@
+package query
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/flanksource/duty/context"
+	"github.com/flanksource/duty/models"
+	"github.com/google/uuid"
+	"gorm.io/gorm"
+)
+
+func FindAgent(ctx context.Context, name string) (*models.Agent, error) {
+	var agent models.Agent
+	err := ctx.DB().Where("name = ?", name).First(&agent).Error
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, nil
+		}
+
+		return nil, err
+	}
+
+	return &agent, nil
+}
+
+func GetAllResourceIDsOfAgent(ctx context.Context, table, from string, size int, agentID uuid.UUID) ([]string, error) {
+	var response []string
+	var err error
+
+	switch table {
+	case "check_statuses":
+		query := `
+		SELECT (check_id::TEXT || ',' || time::TEXT)
+		FROM check_statuses
+		LEFT JOIN checks ON checks.id = check_statuses.check_id
+		WHERE checks.agent_id = ? AND (check_statuses.check_id::TEXT, check_statuses.time::TEXT) > (?, ?)
+		ORDER BY check_statuses.check_id, check_statuses.time
+		LIMIT ?`
+		parts := strings.Split(from, ",")
+		if len(parts) != 2 {
+			return nil, fmt.Errorf("%s is not a valid next cursor. It must consist of check_id and time separated by a comma", from)
+		}
+
+		err = ctx.DB().Raw(query, agentID, parts[0], parts[1], size).Scan(&response).Error
+	default:
+		query := fmt.Sprintf("SELECT id FROM %s WHERE agent_id = ? AND id::TEXT > ? ORDER BY id LIMIT ?", table)
+		err = ctx.DB().Raw(query, agentID, from, size).Scan(&response).Error
+	}
+
+	return response, err
+}
diff --git a/tests/upstream_test.go b/tests/upstream_test.go
new file mode 100644
index 00000000..49e6befa
--- /dev/null
+++ b/tests/upstream_test.go
@@ -0,0 +1,137 @@
+package tests
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/labstack/echo/v4"
+	ginkgo "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+	"github.com/patrickmn/go-cache"
+
+	"github.com/flanksource/duty/context"
+	"github.com/flanksource/duty/models"
+	"github.com/flanksource/duty/tests/setup"
+	"github.com/flanksource/duty/upstream"
+)
+
+var _ = ginkgo.Describe("Config Changes & Analyses sync test", ginkgo.Ordered, func() {
+	var upstreamCtx *context.Context
+	var echoCloser, drop func()
+	var upstreamConf upstream.UpstreamConfig
+	const agentName = "my-agent"
+
+	ginkgo.It("prepare upstream database", func() {
+		var err error
+		upstreamCtx, drop, err = setup.NewDB(DefaultContext, "upstream")
+		Expect(err).ToNot(HaveOccurred())
+
+		var changes int
+		err = upstreamCtx.DB().Select("COUNT(*)").Model(&models.ConfigChange{}).Scan(&changes).Error
+		Expect(err).ToNot(HaveOccurred())
+		Expect(changes).To(Equal(0))
+
+		var analyses int
+		err = upstreamCtx.DB().Select("COUNT(*)").Model(&models.ConfigAnalysis{}).Scan(&analyses).Error
+		Expect(err).ToNot(HaveOccurred())
+		Expect(analyses).To(Equal(0))
+
+		agent := models.Agent{Name: agentName}
+		err = upstreamCtx.DB().Create(&agent).Error
+		Expect(err).ToNot(HaveOccurred())
+	})
+
+	ginkgo.It("should setup upstream echo server", func() {
+		var port int
+		e := echo.New()
+		e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
+			return func(c echo.Context) error {
+				c.SetRequest(c.Request().WithContext(upstreamCtx.Wrap(c.Request().Context())))
+				return next(c)
+			}
+		})
+
+		e.POST("/upstream/push", upstream.PushHandler(cache.New(time.Hour, time.Hour)))
+		e.GET("/upstream/pull/:agent_name", upstream.PullHandler([]string{"config_scrapers", "config_items"}))
+		e.GET("/upstream/status/:agent_name", upstream.StatusHandler([]string{"config_scrapers", "config_items"}))
+
+		port, echoCloser = setup.RunEcho(e)
+
+		upstreamConf = upstream.UpstreamConfig{
+			Host:      fmt.Sprintf("http://localhost:%d", port),
+			AgentName: agentName,
+		}
+	})
+
+	ginkgo.It("should push config items first to satisfy foregin keys for changes & analyses", func() {
+		reconciler := upstream.NewUpstreamReconciler(upstreamConf, 100)
+
+		count, err := reconciler.Sync(DefaultContext, "config_items")
+		Expect(err).To(BeNil())
+		Expect(count).To(Not(BeZero()))
+	})
+
+	ginkgo.It("should sync config_changes to upstream", func() {
+		{
+			var pushed int
+			err := DefaultContext.DB().Select("COUNT(*)").Where("is_pushed = true").Model(&models.ConfigChange{}).Scan(&pushed).Error
+			Expect(err).ToNot(HaveOccurred())
+			Expect(pushed).To(BeZero())
+		}
+
+		var changes int
+		err := upstreamCtx.DB().Select("COUNT(*)").Model(&models.ConfigChange{}).Scan(&changes).Error
+		Expect(err).ToNot(HaveOccurred())
+		Expect(changes).To(BeZero())
+
+		count, err := upstream.SyncConfigChanges(DefaultContext, upstreamConf, 10)
+		Expect(err).ToNot(HaveOccurred())
+
+		err = upstreamCtx.DB().Select("COUNT(*)").Model(&models.ConfigChange{}).Scan(&changes).Error
+		Expect(err).ToNot(HaveOccurred())
+		Expect(changes).To(Equal(count))
+
+		{
+			var pending int
+			err := DefaultContext.DB().Select("COUNT(*)").Where("is_pushed = false").Model(&models.ConfigChange{}).Scan(&pending).Error
+			Expect(err).ToNot(HaveOccurred())
+			Expect(pending).To(BeZero())
+		}
+	})
+
+	ginkgo.It("should sync config_analyses to upstream", func() {
+		{
+			var pushed int
+			err := DefaultContext.DB().Select("COUNT(*)").Where("is_pushed = true").Model(&models.ConfigAnalysis{}).Scan(&pushed).Error
+			Expect(err).ToNot(HaveOccurred())
+			Expect(pushed).To(BeZero())
+		}
+
+		var analyses int
+		err := upstreamCtx.DB().Select("COUNT(*)").Model(&models.ConfigAnalysis{}).Scan(&analyses).Error
+		Expect(err).ToNot(HaveOccurred())
+		Expect(analyses).To(BeZero())
+
+		count, err := upstream.SyncConfigAnalyses(DefaultContext, upstreamConf, 10)
+		Expect(err).ToNot(HaveOccurred())
+
+		err = upstreamCtx.DB().Select("COUNT(*)").Model(&models.ConfigAnalysis{}).Scan(&analyses).Error
+		Expect(err).ToNot(HaveOccurred())
+		Expect(analyses).To(Equal(count))
+
+		{
+			var pending int
+			err := DefaultContext.DB().Select("COUNT(*)").Where("is_pushed = false").Model(&models.ConfigAnalysis{}).Scan(&pending).Error
+			Expect(err).ToNot(HaveOccurred())
+			Expect(pending).To(BeZero())
+		}
+	})
+
+	ginkgo.It("should stop echo server ", func() {
+		echoCloser()
+	})
+
+	ginkgo.It("should drop upstream database ", func() {
+		drop()
+	})
+})
diff --git a/upstream/commands.go b/upstream/commands.go
new file mode 100644
index 00000000..38495243
--- /dev/null
+++ b/upstream/commands.go
@@ -0,0 +1,124 @@
+package upstream
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/flanksource/commons/logger"
+	"github.com/flanksource/duty/context"
+	"github.com/flanksource/duty/models"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+)
+
+func getAgent(ctx context.Context, name string) (*models.Agent, error) {
+	var t models.Agent
+	tx := ctx.DB().Where("name = ?", name).First(&t)
+	return &t, tx.Error
+}
+
+func createAgent(ctx context.Context, name string) (*models.Agent, error) {
+	a := models.Agent{Name: name}
+	tx := ctx.DB().Create(&a)
+	return &a, tx.Error
+}
+
+func GetOrCreateAgent(ctx context.Context, name string) (*models.Agent, error) {
+	a, err := getAgent(ctx, name)
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			newAgent, err := createAgent(ctx, name)
+			if err != nil {
+				return nil, fmt.Errorf("failed to create agent: %w", err)
+			}
+			return newAgent, nil
+		}
+		return nil, err
+	}
+
+	return a, nil
+}
+
+func InsertUpstreamMsg(ctx context.Context, req *PushData) error {
+	batchSize := 100
+	db := ctx.DB()
+	if len(req.Topologies) > 0 {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(req.Topologies, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting topologies: %w", err)
+		}
+	}
+
+	if len(req.Canaries) > 0 {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(req.Canaries, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting canaries: %w", err)
+		}
+	}
+
+	// components are inserted one by one, instead of in a batch, because of the foreign key constraint with itself.
+	for _, c := range req.Components {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&c).Error; err != nil {
+			logger.Errorf("error upserting component (id=%s): %v", c.ID, err)
+		}
+	}
+
+	if len(req.ComponentRelationships) > 0 {
+		cols := []clause.Column{{Name: "component_id"}, {Name: "relationship_id"}, {Name: "selector_id"}}
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true, Columns: cols}).CreateInBatches(req.ComponentRelationships, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting component_relationships: %w", err)
+		}
+	}
+
+	if len(req.ConfigScrapers) > 0 {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(req.ConfigScrapers, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting config scrapers: %w", err)
+		}
+	}
+
+	// config items are inserted one by one, instead of in a batch, because of the foreign key constraint with itself.
+	for _, ci := range req.ConfigItems {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&ci).Error; err != nil {
+			logger.Errorf("error upserting config item (id=%s): %v", ci.ID, err)
+		}
+	}
+
+	if len(req.ConfigRelationships) > 0 {
+		cols := []clause.Column{{Name: "related_id"}, {Name: "config_id"}, {Name: "selector_id"}}
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true, Columns: cols}).CreateInBatches(req.ConfigRelationships, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting config_relationships: %w", err)
+		}
+	}
+
+	if len(req.ConfigComponentRelationships) > 0 {
+		cols := []clause.Column{{Name: "component_id"}, {Name: "config_id"}}
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true, Columns: cols}).CreateInBatches(req.ConfigComponentRelationships, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting config_component_relationships: %w", err)
+		}
+	}
+
+	if len(req.ConfigChanges) > 0 {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(req.ConfigChanges, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting config_changes: %w", err)
+		}
+	}
+
+	if len(req.ConfigAnalysis) > 0 {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(req.ConfigAnalysis, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting config_analysis: %w", err)
+		}
+	}
+
+	if len(req.Checks) > 0 {
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(req.Checks, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting checks: %w", err)
+		}
+	}
+
+	if len(req.CheckStatuses) > 0 {
+		cols := []clause.Column{{Name: "check_id"}, {Name: "time"}}
+		if err := db.Clauses(clause.OnConflict{UpdateAll: true, Columns: cols}).CreateInBatches(req.CheckStatuses, batchSize).Error; err != nil {
+			return fmt.Errorf("error upserting check_statuses: %w", err)
+		}
+	}
+
+	return nil
+}
diff --git a/upstream/controllers.go b/upstream/controllers.go
new file mode 100644
index 00000000..1118a933
--- /dev/null
+++ b/upstream/controllers.go
@@ -0,0 +1,128 @@
+package upstream
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strings"
+
+	"github.com/flanksource/commons/collections"
+	"github.com/flanksource/commons/logger"
+	"github.com/flanksource/duty/api"
+	"github.com/flanksource/duty/context"
+	"github.com/flanksource/duty/query"
+	"github.com/google/uuid"
+	"github.com/labstack/echo/v4"
+	"github.com/patrickmn/go-cache"
+	"go.opentelemetry.io/otel/attribute"
+)
+
+// PullHandler returns a handler that returns all the ids of items it has received from the requested agent.
+func PullHandler(allowedTables []string) func(echo.Context) error {
+	return func(c echo.Context) error {
+		ctx := c.Request().Context().(context.Context)
+		var req PaginateRequest
+		if err := c.Bind(&req); err != nil {
+			return c.JSON(http.StatusBadRequest, api.HTTPError{Error: err.Error()})
+		}
+
+		reqJSON, _ := json.Marshal(req)
+		ctx.GetSpan().SetAttributes(attribute.String("upstream.pull.paginate-request", string(reqJSON)))
+
+		if !collections.Contains(allowedTables, req.Table) {
+			return c.JSON(http.StatusForbidden, api.HTTPError{Error: fmt.Sprintf("table=%s is not allowed", req.Table)})
+		}
+
+		agentName := c.Param("agent_name")
+		agent, err := query.FindAgent(ctx, agentName)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, api.HTTPError{Error: err.Error(), Message: "failed to get agent"})
+		} else if agent == nil {
+			return c.JSON(http.StatusNotFound, api.HTTPError{Message: fmt.Sprintf("agent(name=%s) not found", agentName)})
+		}
+
+		resp, err := query.GetAllResourceIDsOfAgent(ctx, req.Table, req.From, req.Size, agent.ID)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, api.HTTPError{Error: err.Error(), Message: "failed to get resource ids"})
+		}
+
+		return c.JSON(http.StatusOK, resp)
+	}
+}
+
+// PushHandler returns an echo handler that saves the push data from agents.
+func PushHandler(agentIDCache *cache.Cache) func(echo.Context) error {
+	return func(c echo.Context) error {
+		ctx := c.Request().Context().(context.Context)
+
+		var req PushData
+		err := json.NewDecoder(c.Request().Body).Decode(&req)
+		if err != nil {
+			return c.JSON(http.StatusBadRequest, api.HTTPError{Error: err.Error(), Message: "invalid json request"})
+		}
+
+		ctx.GetSpan().SetAttributes(attribute.Int("upstream.push.msg-count", req.Count()))
+
+		req.AgentName = strings.TrimSpace(req.AgentName)
+		if req.AgentName == "" {
+			return c.JSON(http.StatusBadRequest, api.HTTPError{Error: "agent name is required", Message: "agent name is required"})
+		}
+
+		agentID, ok := agentIDCache.Get(req.AgentName)
+		if !ok {
+			agent, err := GetOrCreateAgent(ctx, req.AgentName)
+			if err != nil {
+				return c.JSON(http.StatusBadRequest, api.HTTPError{
+					Error:   err.Error(),
+					Message: "Error while creating/fetching agent",
+				})
+			}
+			agentID = agent.ID
+			agentIDCache.Set(req.AgentName, agentID, cache.DefaultExpiration)
+		}
+
+		req.PopulateAgentID(agentID.(uuid.UUID))
+
+		logger.Tracef("Inserting push data %s", req.String())
+		if err := InsertUpstreamMsg(ctx, &req); err != nil {
+			return c.JSON(http.StatusInternalServerError, api.HTTPError{Error: err.Error(), Message: "failed to upsert upstream message"})
+		}
+
+		return nil
+	}
+}
+
+// StatusHandler returns a handler that returns the summary of all ids the upstream has received.
+func StatusHandler(allowedTables []string) func(echo.Context) error {
+	return func(c echo.Context) error {
+		ctx := c.Request().Context().(context.Context)
+		var req PaginateRequest
+		if err := c.Bind(&req); err != nil {
+			return c.JSON(http.StatusBadRequest, api.HTTPError{Error: err.Error()})
+		}
+
+		reqJSON, _ := json.Marshal(req)
+		ctx.GetSpan().SetAttributes(attribute.String("upstream.status.paginate-request", string(reqJSON)))
+
+		if !collections.Contains(allowedTables, req.Table) {
+			return c.JSON(http.StatusForbidden, api.HTTPError{Error: fmt.Sprintf("table=%s is not allowed", req.Table)})
+		}
+
+		var agentName = c.Param("agent_name")
+		agent, err := query.FindAgent(ctx, agentName)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, api.HTTPError{Error: err.Error(), Message: "failed to get agent"})
+		}
+
+		if agent == nil {
+			return c.JSON(http.StatusNotFound, api.HTTPError{Message: fmt.Sprintf("agent(name=%s) not found", agentName)})
+		}
+
+		response, err := GetPrimaryKeysHash(ctx, req, agent.ID)
+		if err != nil {
+			return c.JSON(http.StatusInternalServerError, api.HTTPError{Error: err.Error(), Message: "failed to push status response"})
+		}
+
+		return c.JSON(http.StatusOK, response)
+	}
+}