From 0c36fc7b31a0af4d7e0dfd87561682fc67c4f40a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Carpintero?= Date: Mon, 9 Dec 2024 18:43:35 +0100 Subject: [PATCH] Improve context management and graceful shutdown --- bootstrapper.go | 23 ++- config/config.go | 4 +- go.mod | 1 + go.sum | 2 + internal/http/check.go | 168 ++++++++++-------- internal/http/check_test.go | 333 ++++++++++++++++++++++++++++++++++++ 6 files changed, 441 insertions(+), 90 deletions(-) create mode 100644 internal/http/check_test.go diff --git a/bootstrapper.go b/bootstrapper.go index 2e3e6c4..6dc765a 100644 --- a/bootstrapper.go +++ b/bootstrapper.go @@ -131,20 +131,19 @@ func NewCheck(name string, checker Checker) Check { conf.Check.Target = runTarget conf.Check.Opts = options c = newLocalCheck(name, checker, logger, conf, json) + } else if conf.Port != nil { + logger.Debug("Http mode") + l := logging.BuildLoggerWithConfigAndFields(conf.Log, log.Fields{ + // "checkTypeName": "TODO", + // "checkTypeVersion": "TODO", + // "component": "checks", + }) + c = http.NewCheck(name, checker, l, conf) } else { - if conf.Port > 0 { - logger.Debug("Http mode") - l := logging.BuildLoggerWithConfigAndFields(conf.Log, log.Fields{ - // "checkTypeName": "TODO", - // "checkTypeVersion": "TODO", - // "component": "checks", - }) - c = http.NewCheck(name, checker, l, conf) - } else { - logger.Debug("Push mode") - c = push.NewCheckWithConfig(name, checker, logger, conf) - } + logger.Debug("Push mode") + c = push.NewCheckWithConfig(name, checker, logger, conf) } + cachedConfig = conf return c } diff --git a/config/config.go b/config/config.go index 3035253..ea91387 100644 --- a/config/config.go +++ b/config/config.go @@ -65,7 +65,7 @@ type Config struct { Log LogConfig `toml:"Log"` CommMode string `toml:"CommMode"` Push rest.PusherConfig `toml:"Push"` - Port int + Port *int `toml:"Port"` AllowPrivateIPs *bool `toml:"AllowPrivateIps"` RequiredVars map[string]string `toml:"RequiredVars"` } @@ -128,7 +128,7 @@ func overrideCommConfigEnvVars(c *Config) { if port != "" { p, err := strconv.Atoi(port) if err == nil { - c.Port = p + c.Port = &p } } diff --git a/go.mod b/go.mod index 1730554..3730c13 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( go.opentelemetry.io/otel v1.24.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/trace v1.24.0 // indirect + go.uber.org/goleak v1.3.0 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/net v0.22.0 // indirect diff --git a/go.sum b/go.sum index 905118f..3a14532 100644 --- a/go.sum +++ b/go.sum @@ -155,6 +155,8 @@ go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGX go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/internal/http/check.go b/internal/http/check.go index a2d69ea..2b5ce09 100644 --- a/internal/http/check.go +++ b/internal/http/check.go @@ -1,5 +1,5 @@ /* -Copyright 2019 Adevinta +Copyright 2024 Adevinta */ package http @@ -30,90 +30,104 @@ type Check struct { checker Checker config *config.Config port int - ctx context.Context - cancel context.CancelFunc + server *http.Server exitSignal chan os.Signal } -// RunAndServe implements the behavior needed by the sdk for a check runner to -// execute a check. -func (c *Check) RunAndServe() { - http.HandleFunc("/run", func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "error reading request body", http.StatusBadRequest) - return - } - var job Job - err = json.Unmarshal(body, &job) - if err != nil { - w.WriteHeader(500) - return - } +// ServeHTTP implements an http POST handler that receives a JSON enconde Job, and returns an +// agent.State JSON enconded response. +func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "error reading request body", http.StatusBadRequest) + return + } + var job Job + err = json.Unmarshal(body, &job) + if err != nil { + w.WriteHeader(500) + return + } - logger := c.Logger.WithFields(log.Fields{ - "target": job.Target, - "checkID": job.CheckID, - }) - ctx := context.WithValue(c.ctx, "logger", logger) - checkState := &State{ - state: agent.State{ - Report: report.Report{ - CheckData: report.CheckData{ - CheckID: job.CheckID, - StartTime: time.Now(), - ChecktypeName: c.config.Check.CheckTypeName, - ChecktypeVersion: c.config.Check.CheckTypeVersion, - Options: job.Options, - Target: job.Target, - }, - ResultData: report.ResultData{}, + logger := c.Logger.WithFields(log.Fields{ + "target": job.Target, + "checkID": job.CheckID, + }) + ctx := context.WithValue(r.Context(), "logger", logger) + checkState := &State{ + state: agent.State{ + Report: report.Report{ + CheckData: report.CheckData{ + CheckID: job.CheckID, + StartTime: job.StartTime, // TODO: Is this correct or should be time.Now() + ChecktypeName: c.config.Check.CheckTypeName, + ChecktypeVersion: c.config.Check.CheckTypeVersion, + Options: job.Options, + Target: job.Target, }, + ResultData: report.ResultData{}, }, - } + }, + } - runtimeState := state.State{ - ResultData: &checkState.state.Report.ResultData, - ProgressReporter: state.ProgressReporterHandler(checkState.SetProgress), - } - logger.WithField("opts", job.Options).Info("Starting check") - err = c.checker.Run(ctx, job.Target, job.AssetType, job.Options, runtimeState) - c.checker.CleanUp(context.Background(), job.Target, job.AssetType, job.Options) - checkState.state.Report.CheckData.EndTime = time.Now() - elapsedTime := time.Since(checkState.state.Report.CheckData.StartTime) - // If an error has been returned, we set the correct status. - if err != nil { - if errors.Is(err, context.Canceled) { - checkState.state.Status = agent.StatusAborted - } else if errors.Is(err, state.ErrAssetUnreachable) { - checkState.state.Status = agent.StatusInconclusive - } else if errors.Is(err, state.ErrNonPublicAsset) { - checkState.state.Status = agent.StatusInconclusive - } else { - c.Logger.WithError(err).Error("Error running check") - checkState.state.Status = agent.StatusFailed - checkState.state.Report.Error = err.Error() - } + runtimeState := state.State{ + ResultData: &checkState.state.Report.ResultData, + ProgressReporter: state.ProgressReporterHandler(checkState.SetProgress), + } + logger.WithField("opts", job.Options).Info("Starting check") + err = c.checker.Run(ctx, job.Target, job.AssetType, job.Options, runtimeState) + c.checker.CleanUp(ctx, job.Target, job.AssetType, job.Options) + checkState.state.Report.CheckData.EndTime = time.Now() + elapsedTime := time.Since(checkState.state.Report.CheckData.StartTime) + // If an error has been returned, we set the correct status. + if err != nil { + if errors.Is(err, context.Canceled) { + checkState.state.Status = agent.StatusAborted + } else if errors.Is(err, state.ErrAssetUnreachable) { + checkState.state.Status = agent.StatusInconclusive + } else if errors.Is(err, state.ErrNonPublicAsset) { + checkState.state.Status = agent.StatusInconclusive } else { - checkState.state.Status = agent.StatusFinished + c.Logger.WithError(err).Error("Error running check") + checkState.state.Status = agent.StatusFailed + checkState.state.Report.Error = err.Error() } - checkState.state.Report.Status = checkState.state.Status + } else { + checkState.state.Status = agent.StatusFinished + } + checkState.state.Report.Status = checkState.state.Status + + logger.WithField("seconds", elapsedTime.Seconds()).WithField("state", checkState.state.Status).Info("Check finished") - logger.WithField("seconds", elapsedTime.Seconds()).WithField("state", checkState.state.Status).Info("Check finished") + // Initialize sync point for the checker and the push state to be finished. + out, err := json.Marshal(checkState.state) + if err != nil { + logger.WithError(err).Error("error marshalling the check state") + http.Error(w, "error marshalling the check state", http.StatusInternalServerError) + return + } + w.Write(out) +} - // Initialize sync point for the checker and the push state to be finished. - out, err := json.Marshal(checkState.state) - if err != nil { - logger.WithError(err).Error("error marshalling the check state") - http.Error(w, "error marshalling the check state", http.StatusInternalServerError) - return +// RunAndServe implements the behavior needed by the sdk for a check runner to +// execute a check. +func (c *Check) RunAndServe() { + http.HandleFunc("/run", c.ServeHTTP) + c.Logger.Info(fmt.Sprintf("Listening at %s", c.server.Addr)) + go func() { + if err := c.server.ListenAndServe(); err != nil { + // handle err } - w.Write(out) - }) + }() - addr := fmt.Sprintf(":%d", c.port) - c.Logger.Info(fmt.Sprintf("Listening at %s", addr)) - log.Fatal(http.ListenAndServe(addr, nil)) + s := <-c.exitSignal + + c.Logger.WithField("signal", s.String()).Info("Stopping server") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // TODO: Allow configure value. + defer cancel() + if err := c.server.Shutdown(ctx); err != nil { + c.Logger.WithError(err).Error("Shutting down server") + } } type Job struct { @@ -129,23 +143,24 @@ type Job struct { RunTime int64 } -// Shutdown is needed to fullfil the check interface but we don't need to do +// Shutdown is needed to fulfil the check interface but we don't need to do // anything in this case. func (c *Check) Shutdown() error { + c.exitSignal <- syscall.SIGTERM return nil } -// NewCheck creates new check to be run from the command line without having an agent. +// NewCheck creates new check to be run from the command line without having an agent. func NewCheck(name string, checker Checker, logger *log.Entry, conf *config.Config) *Check { c := &Check{ Name: name, Logger: logger, config: conf, exitSignal: make(chan os.Signal, 1), - port: conf.Port, + port: *conf.Port, } + c.server = &http.Server{Addr: fmt.Sprintf(":%d", c.port)} signal.Notify(c.exitSignal, syscall.SIGINT, syscall.SIGTERM) - c.ctx, c.cancel = context.WithCancel(context.Background()) c.checker = checker return c } @@ -161,6 +176,7 @@ type Checker interface { CleanUp(ctx context.Context, target, assetType, opts string) } +// SetProgress updates the progress of the state. func (p *State) SetProgress(progress float32) { if p.state.Status == agent.StatusRunning && progress > p.state.Progress { p.state.Progress = progress diff --git a/internal/http/check_test.go b/internal/http/check_test.go new file mode 100644 index 0000000..8e4b230 --- /dev/null +++ b/internal/http/check_test.go @@ -0,0 +1,333 @@ +/* +Copyright 2019 Adevinta +*/ + +package http + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go.uber.org/goleak" + + log "github.com/sirupsen/logrus" + + "github.com/adevinta/vulcan-check-sdk/agent" + "github.com/adevinta/vulcan-check-sdk/config" + "github.com/adevinta/vulcan-check-sdk/internal/logging" + "github.com/adevinta/vulcan-check-sdk/state" + report "github.com/adevinta/vulcan-report" +) + +type CheckerHandleRun func(ctx context.Context, target, assetType, opts string, s state.State) error + +// Run is used as adapter to satisfy the method with same name in interface Checker. +func (handler CheckerHandleRun) Run(ctx context.Context, target, assetType string, opts string, s state.State) error { + return (handler(ctx, target, assetType, opts, s)) +} + +// CheckerHandleCleanUp func type to specify a CleanUp handler function for a checker. +type CheckerHandleCleanUp func(ctx context.Context, target, assetType, opts string) + +// CleanUp is used as adapter to satisfy the method with same name in interface Checker. +func (handler CheckerHandleCleanUp) CleanUp(ctx context.Context, target, assetType, opts string) { + (handler(ctx, target, assetType, opts)) +} + +// NewCheckFromHandler creates a new check given a checker run handler. +func NewCheckFromHandlerWithConfig(name string, run CheckerHandleRun, clean CheckerHandleCleanUp, conf *config.Config, l *log.Entry) *Check { + if clean == nil { + clean = func(ctx context.Context, target, assetType, opts string) {} + } + checkerAdapter := struct { + CheckerHandleRun + CheckerHandleCleanUp + }{ + run, + clean, + } + return NewCheck(name, checkerAdapter, l, conf) +} + +type httpTest struct { + name string + args httpIntParams + want map[string]agent.State + wantResourceState interface{} +} + +type httpIntParams struct { + checkRunner CheckerHandleRun + checkCleaner func(resourceToClean interface{}, ctx context.Context, target, assetType, optJSON string) + resourceToClean interface{} + checkName string + config *config.Config + jobs map[string]Job +} + +// sleepCheckRunner implements a check that sleeps based on the options and generates inconclusive in case of a target with that name. +func sleepCheckRunner(ctx context.Context, target, assetType, optJSON string, st state.State) (err error) { + log := logging.BuildRootLog("TestChecker") + log.Debug("Check running") + st.SetProgress(0.1) + type t struct { + SleepTime int + } + opt := t{} + if optJSON == "" { + return errors.New("error: missing sleep time") + } + if err := json.Unmarshal([]byte(optJSON), &opt); err != nil { + return err + } + if target == "inconclusive" { + return state.ErrAssetUnreachable + } + if opt.SleepTime <= 0 { + return errors.New("error: missing or 0 sleep time") + } + log.Debugf("going sleep %v seconds.", strconv.Itoa(opt.SleepTime)) + + select { + case <-time.After(time.Duration(opt.SleepTime) * time.Second): + log.Debugf("slept successfully %s seconds", strconv.Itoa(opt.SleepTime)) + case <-ctx.Done(): + log.Info("Check aborted") + } + st.AddVulnerabilities(report.Vulnerability{ + Summary: "Summary", + Description: "Test Vulnerability", + }) + return nil +} + +func TestIntegrationHttpMode(t *testing.T) { + port := 8888 + startTime := time.Now() + intTests := []httpTest{ + { + name: "HappyPath", + args: httpIntParams{ + config: &config.Config{ + Check: config.CheckConfig{ + CheckTypeName: "checkTypeName", + }, + Log: config.LogConfig{ + LogFmt: "text", + LogLevel: "debug", + }, + Port: &port, + }, + checkRunner: sleepCheckRunner, + jobs: map[string]Job{ + "checkHappy": { + CheckID: "checkHappy", + Options: `{"SleepTime": 1}`, + Target: "www.example.com", + AssetType: "Hostname", + StartTime: startTime, + }, + "checkDeadline": { + CheckID: "checkDeadline", + Options: `{"SleepTime": 10}`, + Target: "www.example.com", + AssetType: "Hostname", + StartTime: startTime, + }, + "checkInconclusive": { + CheckID: "checkInconclusive", + Options: `{"SleepTime": 1}`, + Target: "inconclusive", + AssetType: "Hostname", + StartTime: startTime, + }, + "checkFailed": { + CheckID: "checkFailed", + Options: `{}`, + Target: "www.example.com", + AssetType: "Hostname", + StartTime: startTime, + }, + }, + resourceToClean: map[string]string{"key": "initial"}, + checkCleaner: func(resource interface{}, ctx context.Context, target, assetType, opt string) { + r := resource.(map[string]string) + r["key"] = "cleaned" + }, + }, + wantResourceState: map[string]string{"key": "cleaned"}, + want: map[string]agent.State{ + "checkHappy": { + Status: agent.StatusFinished, + Report: report.Report{ + CheckData: report.CheckData{ + CheckID: "checkHappy", + ChecktypeName: "checkTypeName", + ChecktypeVersion: "", + Target: "www.example.com", + Options: `{"SleepTime": 1}`, + Status: agent.StatusFinished, + StartTime: startTime, + EndTime: time.Time{}, + }, + ResultData: report.ResultData{ + Vulnerabilities: []report.Vulnerability{ + { + Description: "Test Vulnerability", + Summary: "Summary", + }, + }, + Error: "", + Data: nil, + Notes: "", + }, + }}, + "checkDeadline": { + Status: agent.StatusAborted, + }, + "checkInconclusive": { + Status: agent.StatusInconclusive, + Report: report.Report{ + CheckData: report.CheckData{ + CheckID: "checkInconclusive", + ChecktypeName: "checkTypeName", + ChecktypeVersion: "", + Target: "inconclusive", + Options: `{"SleepTime": 1}`, + Status: agent.StatusInconclusive, + StartTime: startTime, + EndTime: time.Time{}, + }, + }, + }, + "checkFailed": { + Status: agent.StatusFailed, + Report: report.Report{ + CheckData: report.CheckData{ + CheckID: "checkFailed", + ChecktypeName: "checkTypeName", + ChecktypeVersion: "", + Target: "www.example.com", + Options: `{}`, + Status: agent.StatusFailed, + StartTime: startTime, + EndTime: time.Time{}, + }, + ResultData: report.ResultData{ + Error: "error: missing or 0 sleep time", + }, + }, + }, + }, + }, + } + + defer goleak.VerifyNone(t) + + for _, tt := range intTests { + tt := tt + t.Run(tt.name, func(t2 *testing.T) { + conf := tt.args.config + var cleaner func(ctx context.Context, target, assetType, opts string) + if tt.args.checkCleaner != nil { + cleaner = func(ctx context.Context, target, assetType, opts string) { + tt.args.checkCleaner(tt.args.resourceToClean, ctx, target, assetType, opts) + } + } + l := logging.BuildRootLog("httpCheck") + c := NewCheckFromHandlerWithConfig(tt.args.checkName, tt.args.checkRunner, cleaner, conf, l) + go c.RunAndServe() + client := &http.Client{} + url := fmt.Sprintf("http://localhost:%d/run", *tt.args.config.Port) + + type not struct { + check string + resp agent.State + } + + // ch will receibe the results of the concurrent job executions + ch := make(chan not, len(tt.args.jobs)) + wg := sync.WaitGroup{} + + // Runs each job in a go routine with a 3 seconds deadline. + for key, job := range tt.args.jobs { + wg.Add(1) + go func(key string, job Job) { + defer wg.Done() + var err error + n := not{ + check: key, + } + defer func() { + ch <- n + }() + cc, err := json.Marshal(job) + if err != nil { + l.Error("Marshal error", "error", err) + return + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(cc)) + if err != nil { + l.Error("NewRequestWithContext error", "error", err) + return + } + req.Header.Add("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + n.resp = agent.State{Status: agent.StatusAborted} + return + } + l.Error("request error", "error", err) + return + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + l.Error("failed to read response body", "error", err) + return + } + r := agent.State{} + err = json.Unmarshal(body, &r) + if err != nil { + l.Error("Unable to unmarshal response", "error", err) + return + } + + // Compare resource to clean up state with wanted state. + diff := cmp.Diff(tt.wantResourceState, tt.args.resourceToClean) + if diff != "" { + t.Errorf("Error want resource to clean state != got. Diff %s", diff) + } + n.resp = r + }(key, job) + } + wg.Wait() + close(ch) + + results := map[string]agent.State{} + for x := range ch { + results[x.check] = x.resp + } + + diff := cmp.Diff(results, tt.want, cmpopts.IgnoreFields(report.CheckData{}, "EndTime")) + if diff != "" { + t.Errorf("Error in test %s. diffs %+v", tt.name, diff) + } + c.Shutdown() + }) + } +}