Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use aws-sdk v2 in the trusted advisor check #550

Merged
merged 10 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 103 additions & 87 deletions cmd/vulcan-aws-trusted-advisor/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/*
Copyright 2019 Adevinta
*/

package main

import (
Expand All @@ -17,20 +16,19 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/support"
"github.com/sirupsen/logrus"

check "github.com/adevinta/vulcan-check-sdk"
"github.com/adevinta/vulcan-check-sdk/helpers"
checkstate "github.com/adevinta/vulcan-check-sdk/state"
report "github.com/adevinta/vulcan-report"
"github.com/aws/aws-sdk-go/aws/arn"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/support"
"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -107,39 +105,50 @@ func extractLinesFromHTML(htmlText string) []string {
return result
}

func scanAccount(opt options, target, assetType string, logger *logrus.Entry, state checkstate.State) error {
sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
})
if err != nil {
return err
}

func scanAccount(opt options, target, _ string, logger *logrus.Entry, state checkstate.State) error {
assumeRoleEndpoint := os.Getenv("VULCAN_ASSUME_ROLE_ENDPOINT")
role := os.Getenv("ROLE_NAME")

isReachable, err := helpers.IsReachable(target, assetType,
helpers.NewAWSCreds(assumeRoleEndpoint, role))
if err != nil {
logger.Warnf("Can not check asset reachability: %v", err)
}
if !isReachable {
return checkstate.ErrAssetUnreachable
}

parsedARN, err := arn.Parse(target)
if err != nil {
return err
}
creds, err := getCredentials(assumeRoleEndpoint, parsedARN.AccountID, role, logger)
if err != nil {
return err
var cfg aws.Config
jesusfcr marked this conversation as resolved.
Show resolved Hide resolved
if assumeRoleEndpoint != "" {
creds, err := getCredentials(assumeRoleEndpoint, parsedARN.AccountID, role, logger)
if err != nil {
if errors.Is(err, errNoCredentials) {
return checkstate.ErrAssetUnreachable
}
return err
}
credsProvider := credentials.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)
cfg, err = config.LoadDefaultConfig(context.Background(),
config.WithRegion("us-east-1"),
config.WithCredentialsProvider(credsProvider),
)
if err != nil {
return fmt.Errorf("unable to create AWS config: %w", err)
}
} else {
// try to access with the default credentials
cfg, err = config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
if err != nil {
return fmt.Errorf("unable to create AWS config: %w", err)
}
}

s := support.New(sess, &aws.Config{Credentials: creds})
// Validate that the account id in the target ARN matches the account id in the credentials
if req, err := sts.NewFromConfig(cfg).GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}); err != nil {
return fmt.Errorf("unable to get caller identity: %w", err)
} else if *req.Account != parsedARN.AccountID {
return fmt.Errorf("account id in target ARN does not match the account id in the credentials (target ARN: %s, credentials account id: %s)", parsedARN.AccountID, *req.Account)
}

s := support.NewFromConfig(cfg)
// Retrieve checks list
checks, err := s.DescribeTrustedAdvisorChecks(
context.TODO(),
&support.DescribeTrustedAdvisorChecksInput{
Language: aws.String("en"),
})
Expand All @@ -161,16 +170,18 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
continue
}
checkIds = append(checkIds, check.Id)
refreshed, err := s.RefreshTrustedAdvisorCheck(&support.RefreshTrustedAdvisorCheckInput{CheckId: check.Id})
refreshed, err := s.RefreshTrustedAdvisorCheck(context.Background(), &support.RefreshTrustedAdvisorCheckInput{CheckId: check.Id})
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == "InvalidParameterValueException" {
logger.Printf("check '%s' is not refreshable\n", *check.Name)
continue
}
// Haven't found a more elegant way to check for an
// InvalidParameterValueException. This error type is not defined in the
// support/types package as it is for other services.
if strings.Contains(err.Error(), "InvalidParameterValueException") {
logger.Printf("check '%s' is not refreshable\n", *check.Name)
continue
}
return err
}

logger.Printf("check '%s' is refreshed with status: '%s'\n", *check.Name, *refreshed.Status.Status)
if *refreshed.Status.Status == "enqueued" {
enqueued++
Expand All @@ -190,16 +201,18 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
break LOOP
default:
checkStatus, err := s.DescribeTrustedAdvisorCheckRefreshStatuses(
context.Background(),
&support.DescribeTrustedAdvisorCheckRefreshStatusesInput{
CheckIds: checkIds,
},
)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "InvalidParameterValueException" {
return err
}
}
// Haven't found a more elegant way to check for an
// InvalidParameterValueException. This error type is not
// defined in the support/types package as it is for other
// services.
if err != nil && !strings.Contains(err.Error(), "InvalidParameterValueException") {
return fmt.Errorf("unable to check the refresh statuses: %w", err)

}
var pending bool
for _, cs := range checkStatus.Statuses {
Expand Down Expand Up @@ -242,34 +255,32 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st

var checkSummaries *support.DescribeTrustedAdvisorCheckSummariesOutput
checkSummaries, err = s.DescribeTrustedAdvisorCheckSummaries(
&support.DescribeTrustedAdvisorCheckSummariesInput{
context.Background(), &support.DescribeTrustedAdvisorCheckSummariesInput{
CheckIds: []*string{v.Id}})
if err != nil {
return err
}

for _, summary := range checkSummaries.Summaries {
// Only process summaries that has flagged resources
if summary.HasFlaggedResources == nil {
// Only process summaries that has flagged resources.
if !summary.HasFlaggedResources {
continue
}

if summary.HasFlaggedResources != nil && *summary.HasFlaggedResources == false {
continue
}

description := ""
action := ""
recommendedActions := []string{}
additionalResources := []string{}

// Avoid nil pointer dereference when reading *v.Description
// description, recommendedActions and additionalResources will be
// considered as empty
// considered empty.
if v.Description != nil {
iRecommendedAction := strings.Index(*v.Description, tagRecommendedAction)
if iRecommendedAction < 0 {
// No recommended actions
continue
}
iAdditionalResources := strings.Index(*v.Description, tagAdditionalResources)
description = string(*v.Description)[:iRecommendedAction]

// Extract recommendedActions
if iAdditionalResources >= iRecommendedAction+len(tagRecommendedAction) {
recommendedActions = extractLinesFromHTML(string(*v.Description)[iRecommendedAction+len(tagRecommendedAction) : iAdditionalResources])
Expand All @@ -287,26 +298,20 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
}

var checkResults *support.DescribeTrustedAdvisorCheckResultOutput
checkResults, err = s.DescribeTrustedAdvisorCheckResult(&support.DescribeTrustedAdvisorCheckResultInput{CheckId: v.Id})
checkResults, err = s.DescribeTrustedAdvisorCheckResult(context.Background(), &support.DescribeTrustedAdvisorCheckResultInput{CheckId: v.Id})
if err != nil {
return err
}

for _, fr := range checkResults.Result.FlaggedResources {
// Unable to retrieve flagged resource information
if fr == nil {
logger.Warnf("result with CheckID: %s does not contain flagged resource information", *checkResults.Result.CheckId)
continue
}
// PTVUL-860
// Ignore resources that have been marked as supressed/excluded
if *fr.IsSuppressed {
// Ignore resources that have been marked as suppressed/excluded
if fr.IsSuppressed {
logger.Debugf("resource with ResourceID: %s have been marked as excluded", *fr.ResourceId)
continue
}
// Get the alias of the account only if we did not get previously.
if alias == nil {
res, err := accountAlias(creds)
res, err := accountAlias(cfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -359,10 +364,13 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
if v.Name != nil {
summary = "AWS " + *v.Name
}

resourceID := ""
if fr.ResourceId != nil {
resourceID = *fr.ResourceId
}
vuln := report.Vulnerability{
Summary: summary,
Description: description,
Description: action,
Score: score,
// AWS Trusted Advisor provides already an ID generated by
// them, that seems the best option to indicate which is
Expand All @@ -371,7 +379,7 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
// therefore we are using a set of the metadata values
// provided by their checks in the AffectedResourceString
// attribute.
AffectedResource: aws.StringValue(fr.ResourceId),
AffectedResource: resourceID,
AffectedResourceString: affectedResourceStr,
Labels: []string{"issue", "aws"},
Resources: []report.ResourcesGroup{occurrences},
Expand All @@ -397,57 +405,65 @@ type AssumeRoleResponse struct {
SessionToken string `json:"session_token"`
}

func getCredentials(url string, accountID, role string, logger *logrus.Entry) (*credentials.Credentials, error) {
m := map[string]string{"account_id": accountID}
var errNoCredentials = errors.New("unable to decode credentials")

func getCredentials(url string, accountID, role string, logger *logrus.Entry) (*aws.Credentials, error) {
m := map[string]any{"account_id": accountID, "duration": 3600}
if role != "" {
m["role"] = role
}
jsonBody, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("unable to marshal assume role request body for account %s: %w", accountID, err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("unable to create request for the assume role service: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.Errorf("cannot do request: %s", err.Error())
return nil, err
}
defer resp.Body.Close()
defer resp.Body.Close() // nolint

assumeRoleResponse := AssumeRoleResponse{}
buf, err := io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("Cannot read request body %s", err.Error())
logger.Errorf("can not read request body %s", err.Error())
return nil, err
}

err = json.Unmarshal(buf, &assumeRoleResponse)
if err != nil {
logger.Errorf("Cannot decode request %s", err.Error())
logger.Errorf("RequestBody: %s", string(buf))
return nil, err
logger.Errorf("Cannot decode request: %s", err.Error())
logger.Errorf("ResponseBody: %s", string(buf))
return nil, errNoCredentials
}

return credentials.NewStaticCredentials(
assumeRoleResponse.AccessKey,
assumeRoleResponse.SecretAccessKey,
assumeRoleResponse.SessionToken), nil
return &aws.Credentials{
AccessKeyID: assumeRoleResponse.AccessKey,
SecretAccessKey: assumeRoleResponse.SecretAccessKey,
SessionToken: assumeRoleResponse.SessionToken,
}, nil
}

// accountAlias gets one of the current aliases of the account that the
// credentials passed belong to.
func accountAlias(creds *credentials.Credentials) (string, error) {
svc := iam.New(session.New(&aws.Config{Credentials: creds}))
resp, err := svc.ListAccountAliases(&iam.ListAccountAliasesInput{})
func accountAlias(cfg aws.Config) (string, error) {
svc := iam.NewFromConfig(cfg)
resp, err := svc.ListAccountAliases(context.Background(), &iam.ListAccountAliasesInput{})
if err != nil {
return "", err
}
if len(resp.AccountAliases) == 0 {
// No aliases found for the aws account.
return "", nil
}
a := resp.AccountAliases[0]
if a == nil {
return "", errors.New("unexpected nil getting aliases for aws account")
if len(resp.AccountAliases) < 1 {
return "", errors.New("no result getting aliases for aws account")
}
return *a, nil
a := resp.AccountAliases[0]
return a, nil
}
2 changes: 1 addition & 1 deletion cmd/vulcan-aws-trusted-advisor/manifest.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Description = "Runs an AWS Trusted Advisor check against an AWS account"
AssetTypes = ["AWSAccount"]
RequiredVars = ["VULCAN_ASSUME_ROLE_ENDPOINT", "ROLE_NAME"]
RequiredVars = ["VULCAN_ASSUME_ROLE_ENDPOINT", "ROLE_NAME", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY" ,"AWS_SESSION_TOKEN"]
Options = '{"refresh_timeout": 60}'
Loading