diff --git a/src/sessionmanagerplugin/session/session.go b/src/sessionmanagerplugin/session/session.go index 5f708ecf..6b1300e2 100644 --- a/src/sessionmanagerplugin/session/session.go +++ b/src/sessionmanagerplugin/session/session.go @@ -20,6 +20,9 @@ import ( "fmt" "io" "os" + "os/exec" + "regexp" + "strings" "time" "github.com/aws/SSMCLI/src/config" @@ -191,6 +194,23 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { session.SessionId = *startSessionOutput.SessionId session.StreamUrl = *startSessionOutput.StreamUrl + + if len(profile) > 0 { + // Lookup for a custom endpoint inside the profile + ssmmsgEndpointRes, err := exec.Command("aws", "configure", "get", "aws_ssmmessages_endpoint", "--profile", profile).Output() + if err == nil && ssmmsgEndpointRes != nil { + ssmmsgEndpoint := string(ssmmsgEndpointRes) + streamUrl := string(*startSessionOutput.StreamUrl) + // Cleanup Windows/Linux line feeds + ssmmsgEndpoint = strings.TrimSuffix(strings.TrimSuffix(strings.TrimSuffix(ssmmsgEndpoint, "\n"), "\r"), "\n") + if strings.HasSuffix(ssmmsgEndpoint, ".vpce.amazonaws.com") { + // Looks like a legit endpoint, patch the WSS url + fmt.Printf("Custom endpoint detected %s\n", ssmmsgEndpoint) + m := regexp.MustCompile(`ssmmessages.*amazonaws.com`) + session.StreamUrl = m.ReplaceAllString(streamUrl, ssmmsgEndpoint) + } + } + } session.TokenValue = *startSessionOutput.TokenValue session.Endpoint = ssmEndpoint session.ClientId = clientId