Skip to content

Commit

Permalink
Improve context management and graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
jesusfcr committed Dec 9, 2024
1 parent 59365dd commit 0c36fc7
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 90 deletions.
23 changes: 11 additions & 12 deletions bootstrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,19 @@ func NewCheck(name string, checker Checker) Check {
conf.Check.Target = runTarget
conf.Check.Opts = options
c = newLocalCheck(name, checker, logger, conf, json)
} else if conf.Port != nil {
logger.Debug("Http mode")
l := logging.BuildLoggerWithConfigAndFields(conf.Log, log.Fields{
// "checkTypeName": "TODO",
// "checkTypeVersion": "TODO",
// "component": "checks",
})
c = http.NewCheck(name, checker, l, conf)
} else {
if conf.Port > 0 {
logger.Debug("Http mode")
l := logging.BuildLoggerWithConfigAndFields(conf.Log, log.Fields{
// "checkTypeName": "TODO",
// "checkTypeVersion": "TODO",
// "component": "checks",
})
c = http.NewCheck(name, checker, l, conf)
} else {
logger.Debug("Push mode")
c = push.NewCheckWithConfig(name, checker, logger, conf)
}
logger.Debug("Push mode")
c = push.NewCheckWithConfig(name, checker, logger, conf)
}

cachedConfig = conf
return c
}
Expand Down
4 changes: 2 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type Config struct {
Log LogConfig `toml:"Log"`
CommMode string `toml:"CommMode"`
Push rest.PusherConfig `toml:"Push"`
Port int
Port *int `toml:"Port"`
AllowPrivateIPs *bool `toml:"AllowPrivateIps"`
RequiredVars map[string]string `toml:"RequiredVars"`
}
Expand Down Expand Up @@ -128,7 +128,7 @@ func overrideCommConfigEnvVars(c *Config) {
if port != "" {
p, err := strconv.Atoi(port)
if err == nil {
c.Port = p
c.Port = &p
}
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
go.uber.org/goleak v1.3.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/net v0.22.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGX
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
Expand Down
168 changes: 92 additions & 76 deletions internal/http/check.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 Adevinta
Copyright 2024 Adevinta
*/

package http
Expand Down Expand Up @@ -30,90 +30,104 @@ type Check struct {
checker Checker
config *config.Config
port int
ctx context.Context
cancel context.CancelFunc
server *http.Server
exitSignal chan os.Signal
}

// RunAndServe implements the behavior needed by the sdk for a check runner to
// execute a check.
func (c *Check) RunAndServe() {
http.HandleFunc("/run", func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "error reading request body", http.StatusBadRequest)
return
}
var job Job
err = json.Unmarshal(body, &job)
if err != nil {
w.WriteHeader(500)
return
}
// ServeHTTP implements an http POST handler that receives a JSON enconde Job, and returns an
// agent.State JSON enconded response.
func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "error reading request body", http.StatusBadRequest)
return
}
var job Job
err = json.Unmarshal(body, &job)
if err != nil {
w.WriteHeader(500)
return
}

logger := c.Logger.WithFields(log.Fields{
"target": job.Target,
"checkID": job.CheckID,
})
ctx := context.WithValue(c.ctx, "logger", logger)
checkState := &State{
state: agent.State{
Report: report.Report{
CheckData: report.CheckData{
CheckID: job.CheckID,
StartTime: time.Now(),
ChecktypeName: c.config.Check.CheckTypeName,
ChecktypeVersion: c.config.Check.CheckTypeVersion,
Options: job.Options,
Target: job.Target,
},
ResultData: report.ResultData{},
logger := c.Logger.WithFields(log.Fields{
"target": job.Target,
"checkID": job.CheckID,
})
ctx := context.WithValue(r.Context(), "logger", logger)
checkState := &State{
state: agent.State{
Report: report.Report{
CheckData: report.CheckData{
CheckID: job.CheckID,
StartTime: job.StartTime, // TODO: Is this correct or should be time.Now()
ChecktypeName: c.config.Check.CheckTypeName,
ChecktypeVersion: c.config.Check.CheckTypeVersion,
Options: job.Options,
Target: job.Target,
},
ResultData: report.ResultData{},
},
}
},
}

runtimeState := state.State{
ResultData: &checkState.state.Report.ResultData,
ProgressReporter: state.ProgressReporterHandler(checkState.SetProgress),
}
logger.WithField("opts", job.Options).Info("Starting check")
err = c.checker.Run(ctx, job.Target, job.AssetType, job.Options, runtimeState)
c.checker.CleanUp(context.Background(), job.Target, job.AssetType, job.Options)
checkState.state.Report.CheckData.EndTime = time.Now()
elapsedTime := time.Since(checkState.state.Report.CheckData.StartTime)
// If an error has been returned, we set the correct status.
if err != nil {
if errors.Is(err, context.Canceled) {
checkState.state.Status = agent.StatusAborted
} else if errors.Is(err, state.ErrAssetUnreachable) {
checkState.state.Status = agent.StatusInconclusive
} else if errors.Is(err, state.ErrNonPublicAsset) {
checkState.state.Status = agent.StatusInconclusive
} else {
c.Logger.WithError(err).Error("Error running check")
checkState.state.Status = agent.StatusFailed
checkState.state.Report.Error = err.Error()
}
runtimeState := state.State{
ResultData: &checkState.state.Report.ResultData,
ProgressReporter: state.ProgressReporterHandler(checkState.SetProgress),
}
logger.WithField("opts", job.Options).Info("Starting check")
err = c.checker.Run(ctx, job.Target, job.AssetType, job.Options, runtimeState)
c.checker.CleanUp(ctx, job.Target, job.AssetType, job.Options)
checkState.state.Report.CheckData.EndTime = time.Now()
elapsedTime := time.Since(checkState.state.Report.CheckData.StartTime)
// If an error has been returned, we set the correct status.
if err != nil {
if errors.Is(err, context.Canceled) {
checkState.state.Status = agent.StatusAborted
} else if errors.Is(err, state.ErrAssetUnreachable) {
checkState.state.Status = agent.StatusInconclusive
} else if errors.Is(err, state.ErrNonPublicAsset) {
checkState.state.Status = agent.StatusInconclusive
} else {
checkState.state.Status = agent.StatusFinished
c.Logger.WithError(err).Error("Error running check")
checkState.state.Status = agent.StatusFailed
checkState.state.Report.Error = err.Error()
}
checkState.state.Report.Status = checkState.state.Status
} else {
checkState.state.Status = agent.StatusFinished
}
checkState.state.Report.Status = checkState.state.Status

logger.WithField("seconds", elapsedTime.Seconds()).WithField("state", checkState.state.Status).Info("Check finished")

logger.WithField("seconds", elapsedTime.Seconds()).WithField("state", checkState.state.Status).Info("Check finished")
// Initialize sync point for the checker and the push state to be finished.
out, err := json.Marshal(checkState.state)
if err != nil {
logger.WithError(err).Error("error marshalling the check state")
http.Error(w, "error marshalling the check state", http.StatusInternalServerError)
return
}
w.Write(out)
}

// Initialize sync point for the checker and the push state to be finished.
out, err := json.Marshal(checkState.state)
if err != nil {
logger.WithError(err).Error("error marshalling the check state")
http.Error(w, "error marshalling the check state", http.StatusInternalServerError)
return
// RunAndServe implements the behavior needed by the sdk for a check runner to
// execute a check.
func (c *Check) RunAndServe() {
http.HandleFunc("/run", c.ServeHTTP)
c.Logger.Info(fmt.Sprintf("Listening at %s", c.server.Addr))
go func() {
if err := c.server.ListenAndServe(); err != nil {
// handle err
}
w.Write(out)
})
}()

addr := fmt.Sprintf(":%d", c.port)
c.Logger.Info(fmt.Sprintf("Listening at %s", addr))
log.Fatal(http.ListenAndServe(addr, nil))
s := <-c.exitSignal

c.Logger.WithField("signal", s.String()).Info("Stopping server")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // TODO: Allow configure value.
defer cancel()
if err := c.server.Shutdown(ctx); err != nil {
c.Logger.WithError(err).Error("Shutting down server")
}
}

type Job struct {
Expand All @@ -129,23 +143,24 @@ type Job struct {
RunTime int64
}

// Shutdown is needed to fullfil the check interface but we don't need to do
// Shutdown is needed to fulfil the check interface but we don't need to do
// anything in this case.
func (c *Check) Shutdown() error {
c.exitSignal <- syscall.SIGTERM
return nil
}

// NewCheck creates new check to be run from the command line without having an agent.
// NewCheck creates new check to be run from the command line without having an agent.
func NewCheck(name string, checker Checker, logger *log.Entry, conf *config.Config) *Check {
c := &Check{
Name: name,
Logger: logger,
config: conf,
exitSignal: make(chan os.Signal, 1),
port: conf.Port,
port: *conf.Port,
}
c.server = &http.Server{Addr: fmt.Sprintf(":%d", c.port)}
signal.Notify(c.exitSignal, syscall.SIGINT, syscall.SIGTERM)
c.ctx, c.cancel = context.WithCancel(context.Background())
c.checker = checker
return c
}
Expand All @@ -161,6 +176,7 @@ type Checker interface {
CleanUp(ctx context.Context, target, assetType, opts string)
}

// SetProgress updates the progress of the state.
func (p *State) SetProgress(progress float32) {
if p.state.Status == agent.StatusRunning && progress > p.state.Progress {
p.state.Progress = progress
Expand Down
Loading

0 comments on commit 0c36fc7

Please sign in to comment.