diff --git a/cmd/lava/internal/run/run.go b/cmd/lava/internal/run/run.go index 06d0638..7a79954 100644 --- a/cmd/lava/internal/run/run.go +++ b/cmd/lava/internal/run/run.go @@ -255,7 +255,7 @@ func engineRun(targetIdent string, checktype string) (engine.Report, error) { case err != nil && !errors.Is(err, fs.ErrNotExist): return nil, err case err == nil && info.IsDir(): - if *agentConfig.PullPolicy != agentconfig.PullPolicyIfNotPresent && *agentConfig.PullPolicy != agentconfig.PullPolicyNever { + if config.Get(agentConfig.PullPolicy) != agentconfig.PullPolicyIfNotPresent && config.Get(agentConfig.PullPolicy) != agentconfig.PullPolicyNever { return nil, errors.New("path checktypes only allow IfNotPresent and Never pull policies") } @@ -431,8 +431,8 @@ func writeOutputs(rep engine.Report) (report.ExitCode, error) { Severity: &runSeverity, ShowSeverity: showSeverity, Format: &runFmt, - OutputFile: runO, - Metrics: runMetrics, + OutputFile: &runO, + Metrics: &runMetrics, } metrics.Collect("severity", reportConfig.Severity) @@ -447,8 +447,8 @@ func writeOutputs(rep engine.Report) (report.ExitCode, error) { return 0, fmt.Errorf("render report: %w", err) } - if reportConfig.Metrics != "" { - if err = metrics.WriteFile(reportConfig.Metrics); err != nil { + if config.Get(reportConfig.Metrics) != "" { + if err = metrics.WriteFile(config.Get(reportConfig.Metrics)); err != nil { return 0, fmt.Errorf("write metrics: %w", err) } } diff --git a/cmd/lava/internal/scan/scan.go b/cmd/lava/internal/scan/scan.go index 4516a88..1d46cc5 100644 --- a/cmd/lava/internal/scan/scan.go +++ b/cmd/lava/internal/scan/scan.go @@ -6,7 +6,6 @@ package scan import ( "errors" "fmt" - "log/slog" "os" "runtime/debug" "time" @@ -70,9 +69,6 @@ var osExit = os.Exit // debugReadBuildInfo is used by tests to set the command version. var debugReadBuildInfo = debug.ReadBuildInfo -//// DefaultLogLevel default level for logging. -//const DefaultLogLevel = slog.LevelInfo - // runScan is the entry point of the scan command. func runScan(args []string) error { exitCode, err := scan(args) @@ -100,11 +96,7 @@ func scan(args []string) (int, error) { return 0, fmt.Errorf("parse config file: %w", err) } - var logLevel slog.Level - if cfg.LogLevel != nil { - logLevel = *cfg.LogLevel - } - base.LogLevel.Set(logLevel) + base.LogLevel.Set(config.Get(cfg.LogLevel)) bi, ok := debugReadBuildInfo() if !ok { @@ -116,15 +108,11 @@ func scan(args []string) (int, error) { return 0, fmt.Errorf("minimum required version %v", cfg.LavaVersion) } - var severity config.Severity - if cfg.ReportConfig.Severity != nil { - severity = *cfg.ReportConfig.Severity - } metrics.Collect("lava_version", bi.Main.Version) - metrics.Collect("config_version", cfg.LavaVersion) + metrics.Collect("config_version", config.Get(cfg.LavaVersion)) metrics.Collect("checktype_urls", cfg.ChecktypeURLs) metrics.Collect("targets", cfg.Targets) - metrics.Collect("severity", severity) + metrics.Collect("severity", config.Get(cfg.ReportConfig.Severity)) metrics.Collect("exclusion_count", len(cfg.ReportConfig.Exclusions)) eng, err := engine.New(cfg.AgentConfig, cfg.ChecktypeURLs) @@ -152,8 +140,8 @@ func scan(args []string) (int, error) { metrics.Collect("exit_code", exitCode) metrics.Collect("duration", time.Since(startTime).Seconds()) - if cfg.ReportConfig.Metrics != "" { - if err = metrics.WriteFile(cfg.ReportConfig.Metrics); err != nil { + if config.Get(cfg.ReportConfig.Metrics) != "" { + if err = metrics.WriteFile(config.Get(cfg.ReportConfig.Metrics)); err != nil { return 0, fmt.Errorf("write metrics: %w", err) } } diff --git a/internal/config/config.go b/internal/config/config.go index 1aeb295..d8869e5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -60,7 +60,7 @@ var ( // Config represents a Lava configuration. type Config struct { // LavaVersion is the minimum required version of Lava. - LavaVersion string `yaml:"lava"` + LavaVersion *string `yaml:"lava"` // AgentConfig is the configuration of the vulcan-agent. AgentConfig AgentConfig `yaml:"agent"` @@ -123,7 +123,7 @@ func ParseFile(path string) (Config, error) { // validate validates the Lava configuration. func (c Config) validate() error { // Lava version validation. - if !semver.IsValid(c.LavaVersion) { + if !semver.IsValid(Get(c.LavaVersion)) { return ErrInvalidLavaVersion } @@ -148,7 +148,7 @@ func (c Config) validate() error { // the specified version. An invalid semantic version string is // considered incompatible. func (c Config) IsCompatible(v string) bool { - return semver.Compare(v, c.LavaVersion) >= 0 + return semver.Compare(v, Get(c.LavaVersion)) >= 0 } // AgentConfig is the configuration passed to the vulcan-agent. @@ -158,7 +158,7 @@ type AgentConfig struct { // Parallel is the maximum number of checks that can run in // parallel. - Parallel int `yaml:"parallel"` + Parallel *int `yaml:"parallel"` // Vars is the environment variables required by the Vulcan // checktypes. @@ -183,7 +183,7 @@ type ReportConfig struct { Format *OutputFormat `yaml:"format"` // OutputFile is the path of the output file. - OutputFile string `yaml:"output"` + OutputFile *string `yaml:"output"` // Exclusions is a list of findings that will be ignored. For // instance, accepted risks, false positives, etc. @@ -191,12 +191,12 @@ type ReportConfig struct { // ErrorOnStaleExclusions specifies whether Lava should exit // with error when stale exclusions are detected. - ErrorOnStaleExclusions bool `yaml:"errorOnStaleExclusions"` + ErrorOnStaleExclusions *bool `yaml:"errorOnStaleExclusions"` // Metrics is the file where the metrics will be written. // If Metrics is an empty string or not specified in the yaml file, then // the metrics report is not saved. - Metrics string `yaml:"metrics"` + Metrics *string `yaml:"metrics"` } // Target represents the target of a scan. @@ -446,3 +446,13 @@ func (ed ExpirationDate) MarshalText() (text []byte, err error) { func (ed ExpirationDate) String() string { return ed.Format(ExpirationDateLayout) } + +// Get returns the value of a pointer or the zero value if the pointer +// is nil. +func Get[T any](p *T) T { + var v T + if p != nil { + v = *p + } + return v +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index afd8922..0bd5d0a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -29,7 +29,7 @@ func TestParse(t *testing.T) { name: "valid", file: "testdata/valid.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -45,7 +45,7 @@ func TestParse(t *testing.T) { name: "valid env", file: "testdata/valid_env.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -65,7 +65,7 @@ func TestParse(t *testing.T) { name: "invalid env", file: "testdata/invalid_env.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -117,7 +117,7 @@ func TestParse(t *testing.T) { name: "critical severity", file: "testdata/critical_severity.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -142,7 +142,7 @@ func TestParse(t *testing.T) { name: "low show", file: "testdata/low_show.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -161,7 +161,7 @@ func TestParse(t *testing.T) { name: "never pull policy", file: "testdata/never_pull_policy.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -192,7 +192,7 @@ func TestParse(t *testing.T) { name: "JSON output format", file: "testdata/json_output_format.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -217,7 +217,7 @@ func TestParse(t *testing.T) { name: "debug log level", file: "testdata/debug_log_level.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -240,7 +240,7 @@ func TestParse(t *testing.T) { name: "valid expiration date", file: "testdata/valid_expiration_date.yaml", want: Config{ - LavaVersion: "v1.0.0", + LavaVersion: ptr("v1.0.0"), ChecktypeURLs: []string{ "checktypes.json", }, @@ -251,7 +251,7 @@ func TestParse(t *testing.T) { }, }, ReportConfig: ReportConfig{ - OutputFile: "", + OutputFile: nil, Exclusions: []Exclusion{ { Summary: "Secret Leaked in Git Repository", @@ -309,31 +309,31 @@ func TestConfig_IsCompatible(t *testing.T) { }{ { name: "same version", - cfg: Config{LavaVersion: "v1.0.0"}, + cfg: Config{LavaVersion: ptr("v1.0.0")}, v: "v1.0.0", want: true, }, { name: "lower version", - cfg: Config{LavaVersion: "v1.1.0"}, + cfg: Config{LavaVersion: ptr("v1.1.0")}, v: "1.0.0", want: false, }, { name: "higher version", - cfg: Config{LavaVersion: "v1.0.0"}, + cfg: Config{LavaVersion: ptr("v1.0.0")}, v: "v1.1.0", want: true, }, { name: "pre-release", - cfg: Config{LavaVersion: "v0.0.0"}, + cfg: Config{LavaVersion: ptr("v0.0.0")}, v: "v0.0.0-20231216173526-1150d51c5272", want: false, }, { name: "invalid version", - cfg: Config{LavaVersion: "v1.0.0"}, + cfg: Config{LavaVersion: ptr("v1.0.0")}, v: "invalid", want: false, }, diff --git a/internal/engine/engine.go b/internal/engine/engine.go index c8d6cfe..2abe4b1 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -91,7 +91,7 @@ func newAgentConfig(cli containers.DockerdClient, cfg config.AgentConfig) (agent return agentconfig.Config{}, fmt.Errorf("get gateway interface address: %w", err) } - parallel := cfg.Parallel + parallel := config.Get(cfg.Parallel) if parallel == 0 { parallel = 1 } @@ -109,10 +109,6 @@ func newAgentConfig(cli containers.DockerdClient, cfg config.AgentConfig) (agent Pass: r.Password, }) } - var pullPolicy agentconfig.PullPolicy - if cfg.PullPolicy != nil { - pullPolicy = *cfg.PullPolicy - } acfg := agentconfig.Config{ Agent: agentconfig.AgentConfig{ @@ -131,7 +127,7 @@ func newAgentConfig(cli containers.DockerdClient, cfg config.AgentConfig) (agent Runtime: agentconfig.RuntimeConfig{ Docker: agentconfig.DockerConfig{ Registry: agentconfig.RegistryConfig{ - PullPolicy: pullPolicy, + PullPolicy: config.Get(cfg.PullPolicy), BackoffMaxRetries: 5, BackoffInterval: 5, BackoffJitterFactor: 0.5, diff --git a/internal/report/report.go b/internal/report/report.go index f0b0b62..cb8c542 100644 --- a/internal/report/report.go +++ b/internal/report/report.go @@ -38,11 +38,7 @@ var timeNow = time.Now // NewWriter creates a new instance of a report writer. func NewWriter(cfg config.ReportConfig) (Writer, error) { var prn printer - var format config.OutputFormat - if cfg.Format != nil { - format = *cfg.Format - } - switch format { + switch config.Get(cfg.Format) { case config.OutputFormatHuman: prn = humanPrinter{} case config.OutputFormatJSON: @@ -53,8 +49,8 @@ func NewWriter(cfg config.ReportConfig) (Writer, error) { w := os.Stdout isStdout := true - if cfg.OutputFile != "" { - f, err := os.Create(cfg.OutputFile) + if config.Get(cfg.OutputFile) != "" { + f, err := os.Create(config.Get(cfg.OutputFile)) if err != nil { return Writer{}, fmt.Errorf("create file: %w", err) } @@ -62,25 +58,21 @@ func NewWriter(cfg config.ReportConfig) (Writer, error) { isStdout = false } - var severity config.Severity - if cfg.Severity != nil { - severity = *cfg.Severity - } var showSeverity config.Severity if cfg.ShowSeverity != nil { showSeverity = *cfg.ShowSeverity } else { - showSeverity = severity + showSeverity = config.Get(cfg.Severity) } return Writer{ prn: prn, w: w, isStdout: isStdout, - minSeverity: severity, + minSeverity: config.Get(cfg.Severity), showSeverity: showSeverity, exclusions: cfg.Exclusions, - errorOnStaleExclusions: cfg.ErrorOnStaleExclusions, + errorOnStaleExclusions: config.Get(cfg.ErrorOnStaleExclusions), }, nil } diff --git a/internal/report/report_test.go b/internal/report/report_test.go index 959abf5..a019704 100644 --- a/internal/report/report_test.go +++ b/internal/report/report_test.go @@ -270,7 +270,7 @@ func TestWriter_calculateExitCode(t *testing.T) { }, rConfig: config.ReportConfig{ Severity: ptr(config.SeverityHigh), - ErrorOnStaleExclusions: true, + ErrorOnStaleExclusions: ptr(true), Exclusions: []config.Exclusion{ { Summary: "Unused exclusion", @@ -1574,7 +1574,7 @@ func TestNewWriter_OutputFile(t *testing.T) { }, rConfig: config.ReportConfig{ Severity: ptr(config.SeverityInfo), - OutputFile: "test.json", + OutputFile: ptr("test.json"), Format: ptr(config.OutputFormatJSON), }, wantExitCode: ExitCodeInfo, @@ -1589,7 +1589,7 @@ func TestNewWriter_OutputFile(t *testing.T) { } defer os.RemoveAll(tmpPath) - tt.rConfig.OutputFile = path.Join(tmpPath, tt.rConfig.OutputFile) + tt.rConfig.OutputFile = ptr(path.Join(tmpPath, config.Get(tt.rConfig.OutputFile))) writer, err := NewWriter(tt.rConfig) if err != nil { t.Fatalf("unable to create a report writer: %v", err) @@ -1603,7 +1603,7 @@ func TestNewWriter_OutputFile(t *testing.T) { t.Errorf("unexpected error value: got: %d, want: %d", gotExitCode, tt.wantExitCode) } - if _, err = os.Stat(tt.rConfig.OutputFile); err != nil { + if _, err = os.Stat(config.Get(tt.rConfig.OutputFile)); err != nil { t.Fatalf("unexpected error value: %v", err) } })