Skip to content

Commit

Permalink
Refactor with duplicated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jesusfcr committed Jan 17, 2025
1 parent 61d2d45 commit f0dcdc8
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 239 deletions.
186 changes: 112 additions & 74 deletions cmd/vulcan-aws-alerts/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@ import (
"io"
"net/http"
"os"
"time"

check "github.com/adevinta/vulcan-check-sdk"
checkstate "github.com/adevinta/vulcan-check-sdk/state"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/sirupsen/logrus"

check "github.com/adevinta/vulcan-check-sdk"
checkstate "github.com/adevinta/vulcan-check-sdk/state"
"github.com/aws/aws-sdk-go/aws/arn"
)

Expand All @@ -32,58 +30,22 @@ func main() {
run := func(ctx context.Context, target, assetType, optJSON string, state checkstate.State) error {
logger := check.NewCheckLog(checkName)

if target == "" {
return fmt.Errorf("check target missing")
}

assumeRoleEndpoint := os.Getenv("VULCAN_ASSUME_ROLE_ENDPOINT")
role := os.Getenv("ROLE_NAME")

parsedARN, err := arn.Parse(target)
if err != nil {
return err
return fmt.Errorf("unable to parse ARN: %w", err)
}
assumeRoleEndpoint := os.Getenv("VULCAN_ASSUME_ROLE_ENDPOINT")
roleName := os.Getenv("VULCAN_ASSUME_ROLE_ENDPOINT")

var cfg aws.Config
var creds aws.Credentials
if assumeRoleEndpoint != "" {
c, err := getCredentials(assumeRoleEndpoint, parsedARN.AccountID, role, logger)
if err != nil {
if errors.Is(err, errNoCredentials) {
return checkstate.ErrAssetUnreachable
}
return err
}
creds = *c
if assumeRoleEndpoint == "" {
cfg, err = GetAwsConfig(target, roleName, 3600)
} else {
// try to access with the default credentials
// TODO: Review when the error should be an checkstate.ErrAssetUnreachable (INCONCLUSIVE)
cfg, err = config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
if err != nil {
return fmt.Errorf("unable to create AWS config: %w", err)
}
stsSvc := sts.NewFromConfig(cfg)
roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", parsedARN.AccountID, role)
prov := stscreds.NewAssumeRoleProvider(stsSvc, roleArn)
creds, err = prov.Retrieve(context.Background())
if err != nil {
return fmt.Errorf("unable to assume role: %w", err)
}
}
cfg, err = GetAwsConfigWithVulcanAssumeRole(assumeRoleEndpoint, parsedARN.AccountID, roleName, 3600)

credsProvider := credentials.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)
cfg, err = config.LoadDefaultConfig(context.Background(),
config.WithRegion("us-east-1"),
config.WithCredentialsProvider(credsProvider),
)
if err != nil {
return fmt.Errorf("unable to create AWS config: %w", err)
}

// Validate that the account id in the target ARN matches the account id in the credentials
if req, err := sts.NewFromConfig(cfg).GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}); err != nil {
return fmt.Errorf("unable to get caller identity: %w", err)
} else if *req.Account != parsedARN.AccountID {
return fmt.Errorf("account id in target ARN does not match the account id in the credentials (target ARN: %s, credentials account id: %s)", parsedARN.AccountID, *req.Account)
if err != nil {
return fmt.Errorf("unable to get AWS config: %w", err)
}

return caCertificateRotation(logger, cfg, parsedARN.AccountID, state)
Expand All @@ -92,53 +54,129 @@ func main() {
c.RunAndServe()
}

// AssumeRoleResponse represent a response from vulcan-assume-role
type AssumeRoleResponse struct {
AccessKey string `json:"access_key"`
SecretAccessKey string `json:"secret_access_key"`
SessionToken string `json:"session_token"`
}

var errNoCredentials = errors.New("unable to decode credentials")

func getCredentials(url string, accountID, role string, logger *logrus.Entry) (*aws.Credentials, error) {
m := map[string]any{"account_id": accountID, "duration": 3600}
if role != "" {
m["role"] = role
type VulcanAssumeRoleProvider struct {
URL string
AccountID string
Role string
Duration int
}

func NewVulcanAssumeRoleProvider(url, accountID, role string, duration int) *VulcanAssumeRoleProvider {
return &VulcanAssumeRoleProvider{
URL: url,
AccountID: accountID,
Role: role,
Duration: duration,
}
jsonBody, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("unable to marshal assume role request body for account %s: %w", accountID, err)
}

func (p *VulcanAssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
var emptyCreds aws.Credentials
m := map[string]any{"account_id": p.AccountID, "duration": p.Duration}
if p.Role != "" {
m["role"] = p.Role
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
expires := time.Now().Add(time.Second * time.Duration(p.Duration))
jsonBody, _ := json.Marshal(m) // nolint
req, err := http.NewRequest("POST", p.URL, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("unable to create request for the assume role service: %w", err)
return emptyCreds, fmt.Errorf("unable to create request for the assume role service: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.Errorf("cannot do request: %s", err.Error())
return nil, err
return emptyCreds, fmt.Errorf("cannot do request: %w", err)
}
defer resp.Body.Close() // nolint

type AssumeRoleResponse struct {
AccessKey string `json:"access_key"`
SecretAccessKey string `json:"secret_access_key"`
SessionToken string `json:"session_token"`
}
assumeRoleResponse := AssumeRoleResponse{}
buf, err := io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("can not read request body %s", err.Error())
return nil, err
return emptyCreds, fmt.Errorf("can not read request body %w", err)
}

err = json.Unmarshal(buf, &assumeRoleResponse)
if err != nil {
logger.Errorf("Cannot decode request: %s", err.Error())
logger.Errorf("ResponseBody: %s", string(buf))
return nil, errNoCredentials
return emptyCreds, fmt.Errorf("cannot decode request: %s body: %s: %w", err.Error(), string(buf), errNoCredentials)
}
return &aws.Credentials{
return aws.Credentials{
Source: "VulcanAssumeRoleProvider",
AccessKeyID: assumeRoleResponse.AccessKey,
SecretAccessKey: assumeRoleResponse.SecretAccessKey,
SessionToken: assumeRoleResponse.SessionToken,
AccountID: p.AccountID,
CanExpire: true,
Expires: expires,
}, nil
}

func GetAwsConfig(accountArn, role string, duration int) (aws.Config, error) {
var cfg aws.Config
parsedARN, err := arn.Parse(accountArn)
if err != nil {
return cfg, err
}

cfg, err = config.LoadDefaultConfig(
context.Background(),
config.WithRegion("us-east-1"))
if err != nil {
return cfg, fmt.Errorf("unable to create default AWS config: %w", err)
}
if role != "" {
cfg, err = config.LoadDefaultConfig(
context.Background(),
config.WithRegion("us-east-1"),
config.WithCredentialsProvider(
stscreds.NewAssumeRoleProvider(
sts.NewFromConfig(cfg),
fmt.Sprintf("arn:aws:iam::%s:role/%s", parsedARN.AccountID, role),
func(o *stscreds.AssumeRoleOptions) {
if duration != 0 {
o.Duration = time.Duration(duration) * time.Second
}
},
)))
}
if err != nil {
return cfg, fmt.Errorf("unable to create AWS config: %w", err)
}
// Validate that the account id in the target ARN matches the account id in the credentials
if req, err := sts.NewFromConfig(cfg).GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}); err != nil {
return cfg, fmt.Errorf("unable to get caller identity: %w", err)
} else if *req.Account != parsedARN.AccountID {
return cfg, fmt.Errorf("account id in target ARN does not match the account id in the credentials (target ARN: %s, credentials account id: %s)", parsedARN.AccountID, *req.Account)
}
return cfg, nil
}

func GetAwsConfigWithVulcanAssumeRole(assumeRoleEndpoint, accountArn, role string, duration int) (aws.Config, error) {
var cfg aws.Config
parsedARN, err := arn.Parse(accountArn)
if err != nil {
return cfg, err
}
cfg, err = config.LoadDefaultConfig(
context.Background(),
config.WithCredentialsProvider(
NewVulcanAssumeRoleProvider(assumeRoleEndpoint, parsedARN.AccountID, role, duration),
))
if err != nil {
return cfg, fmt.Errorf("unable to create AWS config: %w", err)
}
// Validate that the account id in the target ARN matches the account id in the credentials
if req, err := sts.NewFromConfig(cfg).GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}); err != nil {
return cfg, fmt.Errorf("unable to get caller identity: %w", err)
} else if *req.Account != parsedARN.AccountID {
return cfg, fmt.Errorf("account id in target ARN does not match the account id in the credentials (target ARN: %s, credentials account id: %s)", parsedARN.AccountID, *req.Account)
}
return cfg, nil
}
Loading

0 comments on commit f0dcdc8

Please sign in to comment.