Skip to content

Commit

Permalink
Use aws-sdk v2 in the trusted advisor check (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
manelmontilla authored Jul 15, 2024
1 parent f332445 commit d048c22
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 104 deletions.
190 changes: 103 additions & 87 deletions cmd/vulcan-aws-trusted-advisor/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/*
Copyright 2019 Adevinta
*/

package main

import (
Expand All @@ -17,20 +16,19 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/support"
"github.com/sirupsen/logrus"

check "github.com/adevinta/vulcan-check-sdk"
"github.com/adevinta/vulcan-check-sdk/helpers"
checkstate "github.com/adevinta/vulcan-check-sdk/state"
report "github.com/adevinta/vulcan-report"
"github.com/aws/aws-sdk-go/aws/arn"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/support"
"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -107,39 +105,50 @@ func extractLinesFromHTML(htmlText string) []string {
return result
}

func scanAccount(opt options, target, assetType string, logger *logrus.Entry, state checkstate.State) error {
sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
})
if err != nil {
return err
}

func scanAccount(opt options, target, _ string, logger *logrus.Entry, state checkstate.State) error {
assumeRoleEndpoint := os.Getenv("VULCAN_ASSUME_ROLE_ENDPOINT")
role := os.Getenv("ROLE_NAME")

isReachable, err := helpers.IsReachable(target, assetType,
helpers.NewAWSCreds(assumeRoleEndpoint, role))
if err != nil {
logger.Warnf("Can not check asset reachability: %v", err)
}
if !isReachable {
return checkstate.ErrAssetUnreachable
}

parsedARN, err := arn.Parse(target)
if err != nil {
return err
}
creds, err := getCredentials(assumeRoleEndpoint, parsedARN.AccountID, role, logger)
if err != nil {
return err
var cfg aws.Config
if assumeRoleEndpoint != "" {
creds, err := getCredentials(assumeRoleEndpoint, parsedARN.AccountID, role, logger)
if err != nil {
if errors.Is(err, errNoCredentials) {
return checkstate.ErrAssetUnreachable
}
return err
}
credsProvider := credentials.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)
cfg, err = config.LoadDefaultConfig(context.Background(),
config.WithRegion("us-east-1"),
config.WithCredentialsProvider(credsProvider),
)
if err != nil {
return fmt.Errorf("unable to create AWS config: %w", err)
}
} else {
// try to access with the default credentials
cfg, err = config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
if err != nil {
return fmt.Errorf("unable to create AWS config: %w", err)
}
}

s := support.New(sess, &aws.Config{Credentials: creds})
// Validate that the account id in the target ARN matches the account id in the credentials
if req, err := sts.NewFromConfig(cfg).GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}); err != nil {
return fmt.Errorf("unable to get caller identity: %w", err)
} else if *req.Account != parsedARN.AccountID {
return fmt.Errorf("account id in target ARN does not match the account id in the credentials (target ARN: %s, credentials account id: %s)", parsedARN.AccountID, *req.Account)
}

s := support.NewFromConfig(cfg)
// Retrieve checks list
checks, err := s.DescribeTrustedAdvisorChecks(
context.TODO(),
&support.DescribeTrustedAdvisorChecksInput{
Language: aws.String("en"),
})
Expand All @@ -161,16 +170,18 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
continue
}
checkIds = append(checkIds, check.Id)
refreshed, err := s.RefreshTrustedAdvisorCheck(&support.RefreshTrustedAdvisorCheckInput{CheckId: check.Id})
refreshed, err := s.RefreshTrustedAdvisorCheck(context.Background(), &support.RefreshTrustedAdvisorCheckInput{CheckId: check.Id})
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == "InvalidParameterValueException" {
logger.Printf("check '%s' is not refreshable\n", *check.Name)
continue
}
// Haven't found a more elegant way to check for an
// InvalidParameterValueException. This error type is not defined in the
// support/types package as it is for other services.
if strings.Contains(err.Error(), "InvalidParameterValueException") {
logger.Printf("check '%s' is not refreshable\n", *check.Name)
continue
}
return err
}

logger.Printf("check '%s' is refreshed with status: '%s'\n", *check.Name, *refreshed.Status.Status)
if *refreshed.Status.Status == "enqueued" {
enqueued++
Expand All @@ -190,16 +201,18 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
break LOOP
default:
checkStatus, err := s.DescribeTrustedAdvisorCheckRefreshStatuses(
context.Background(),
&support.DescribeTrustedAdvisorCheckRefreshStatusesInput{
CheckIds: checkIds,
},
)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "InvalidParameterValueException" {
return err
}
}
// Haven't found a more elegant way to check for an
// InvalidParameterValueException. This error type is not
// defined in the support/types package as it is for other
// services.
if err != nil && !strings.Contains(err.Error(), "InvalidParameterValueException") {
return fmt.Errorf("unable to check the refresh statuses: %w", err)

}
var pending bool
for _, cs := range checkStatus.Statuses {
Expand Down Expand Up @@ -242,34 +255,32 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st

var checkSummaries *support.DescribeTrustedAdvisorCheckSummariesOutput
checkSummaries, err = s.DescribeTrustedAdvisorCheckSummaries(
&support.DescribeTrustedAdvisorCheckSummariesInput{
context.Background(), &support.DescribeTrustedAdvisorCheckSummariesInput{
CheckIds: []*string{v.Id}})
if err != nil {
return err
}

for _, summary := range checkSummaries.Summaries {
// Only process summaries that has flagged resources
if summary.HasFlaggedResources == nil {
// Only process summaries that has flagged resources.
if !summary.HasFlaggedResources {
continue
}

if summary.HasFlaggedResources != nil && *summary.HasFlaggedResources == false {
continue
}

description := ""
action := ""
recommendedActions := []string{}
additionalResources := []string{}

// Avoid nil pointer dereference when reading *v.Description
// description, recommendedActions and additionalResources will be
// considered as empty
// considered empty.
if v.Description != nil {
iRecommendedAction := strings.Index(*v.Description, tagRecommendedAction)
if iRecommendedAction < 0 {
// No recommended actions
continue
}
iAdditionalResources := strings.Index(*v.Description, tagAdditionalResources)
description = string(*v.Description)[:iRecommendedAction]

// Extract recommendedActions
if iAdditionalResources >= iRecommendedAction+len(tagRecommendedAction) {
recommendedActions = extractLinesFromHTML(string(*v.Description)[iRecommendedAction+len(tagRecommendedAction) : iAdditionalResources])
Expand All @@ -287,26 +298,20 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
}

var checkResults *support.DescribeTrustedAdvisorCheckResultOutput
checkResults, err = s.DescribeTrustedAdvisorCheckResult(&support.DescribeTrustedAdvisorCheckResultInput{CheckId: v.Id})
checkResults, err = s.DescribeTrustedAdvisorCheckResult(context.Background(), &support.DescribeTrustedAdvisorCheckResultInput{CheckId: v.Id})
if err != nil {
return err
}

for _, fr := range checkResults.Result.FlaggedResources {
// Unable to retrieve flagged resource information
if fr == nil {
logger.Warnf("result with CheckID: %s does not contain flagged resource information", *checkResults.Result.CheckId)
continue
}
// PTVUL-860
// Ignore resources that have been marked as supressed/excluded
if *fr.IsSuppressed {
// Ignore resources that have been marked as suppressed/excluded
if fr.IsSuppressed {
logger.Debugf("resource with ResourceID: %s have been marked as excluded", *fr.ResourceId)
continue
}
// Get the alias of the account only if we did not get previously.
if alias == nil {
res, err := accountAlias(creds)
res, err := accountAlias(cfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -359,10 +364,13 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
if v.Name != nil {
summary = "AWS " + *v.Name
}

resourceID := ""
if fr.ResourceId != nil {
resourceID = *fr.ResourceId
}
vuln := report.Vulnerability{
Summary: summary,
Description: description,
Description: action,
Score: score,
// AWS Trusted Advisor provides already an ID generated by
// them, that seems the best option to indicate which is
Expand All @@ -371,7 +379,7 @@ func scanAccount(opt options, target, assetType string, logger *logrus.Entry, st
// therefore we are using a set of the metadata values
// provided by their checks in the AffectedResourceString
// attribute.
AffectedResource: aws.StringValue(fr.ResourceId),
AffectedResource: resourceID,
AffectedResourceString: affectedResourceStr,
Labels: []string{"issue", "aws"},
Resources: []report.ResourcesGroup{occurrences},
Expand All @@ -397,57 +405,65 @@ type AssumeRoleResponse struct {
SessionToken string `json:"session_token"`
}

func getCredentials(url string, accountID, role string, logger *logrus.Entry) (*credentials.Credentials, error) {
m := map[string]string{"account_id": accountID}
var errNoCredentials = errors.New("unable to decode credentials")

func getCredentials(url string, accountID, role string, logger *logrus.Entry) (*aws.Credentials, error) {
m := map[string]any{"account_id": accountID, "duration": 3600}
if role != "" {
m["role"] = role
}
jsonBody, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("unable to marshal assume role request body for account %s: %w", accountID, err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("unable to create request for the assume role service: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.Errorf("cannot do request: %s", err.Error())
return nil, err
}
defer resp.Body.Close()
defer resp.Body.Close() // nolint

assumeRoleResponse := AssumeRoleResponse{}
buf, err := io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("Cannot read request body %s", err.Error())
logger.Errorf("can not read request body %s", err.Error())
return nil, err
}

err = json.Unmarshal(buf, &assumeRoleResponse)
if err != nil {
logger.Errorf("Cannot decode request %s", err.Error())
logger.Errorf("RequestBody: %s", string(buf))
return nil, err
logger.Errorf("Cannot decode request: %s", err.Error())
logger.Errorf("ResponseBody: %s", string(buf))
return nil, errNoCredentials
}

return credentials.NewStaticCredentials(
assumeRoleResponse.AccessKey,
assumeRoleResponse.SecretAccessKey,
assumeRoleResponse.SessionToken), nil
return &aws.Credentials{
AccessKeyID: assumeRoleResponse.AccessKey,
SecretAccessKey: assumeRoleResponse.SecretAccessKey,
SessionToken: assumeRoleResponse.SessionToken,
}, nil
}

// accountAlias gets one of the current aliases of the account that the
// credentials passed belong to.
func accountAlias(creds *credentials.Credentials) (string, error) {
svc := iam.New(session.New(&aws.Config{Credentials: creds}))
resp, err := svc.ListAccountAliases(&iam.ListAccountAliasesInput{})
func accountAlias(cfg aws.Config) (string, error) {
svc := iam.NewFromConfig(cfg)
resp, err := svc.ListAccountAliases(context.Background(), &iam.ListAccountAliasesInput{})
if err != nil {
return "", err
}
if len(resp.AccountAliases) == 0 {
// No aliases found for the aws account.
return "", nil
}
a := resp.AccountAliases[0]
if a == nil {
return "", errors.New("unexpected nil getting aliases for aws account")
if len(resp.AccountAliases) < 1 {
return "", errors.New("no result getting aliases for aws account")
}
return *a, nil
a := resp.AccountAliases[0]
return a, nil
}
Loading

0 comments on commit d048c22

Please sign in to comment.