diff --git a/README.md b/README.md index e772c273..c665a2bf 100644 --- a/README.md +++ b/README.md @@ -308,6 +308,12 @@ You can use this by adding the following to your `~/.aws/credentials` file: credential_process = clisso get my-app --output credential_process ``` +Alternatively you can run the following command to configure all Apps for use with `credential_process`: + +```bash +clisso cp configure +``` + The AWS SDK does not cache any credentials obtained using `credential_process`. This means that every time you use the profile, Clisso will be called to obtain new credentials. If you want to cache the credentials, you can use the `--cache` flag. For example: ```ini @@ -323,6 +329,21 @@ global: enable: true ``` +#### Temporarily Disabling Credential Process Functionality + +Different processes on your system might continue using AWS Profiles configured for use with Clisso. To temporarily disable the `credential_process` functionality, you can use the `clisso cp` submenu. For example: + +```bash +clisso cp disable # to disable +clisso cp enable # to enable +clisso cp status # to check the status +``` + +If you disable the `credential_process` functionality, all refreshes will be disabled. While cached credentials will still be used, new credentials will not be fetched. This can be useful if you lock your computer with an active, e.g., VSCode session with CodeCommit. If you wouldn't disable the `credential_process` functionality, the VSCode would constantly trigger new credential requests to refresh the remote CodeCommit repository. + +If you want to check the status programmatically, you can use the exit code of the `clisso cp status` command. If the exit code is `0`, the `credential_process` functionality is enabled. If the exit code is `1`, the `credential_process` functionality is disabled. + + ### Storing the password in the key chain > WARNING: Storing the password without having MFA enabled is a security risk. It allows anyone diff --git a/aws/aws.go b/aws/aws.go index fc3b5bb2..37784f09 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -12,7 +12,6 @@ import ( "github.com/allcloud-io/clisso/log" "github.com/go-ini/ini" - "github.com/sirupsen/logrus" ) // Credentials represents a set of temporary credentials received from AWS STS @@ -33,6 +32,11 @@ type Profile struct { const expireKey = "aws_expiration" +const credentialProcessFormat = "clisso -o credential_process get %s" +const errCannotBeUsed = "Profile %s contains key %s, which indicates, it should not be used by clisso" +const infoProfileConfigured = "Profile %s is now configured for credential_process" +const infoProfileAlreadyConfigured = "Profile %s is already configured for credential_process" + func validateSection(cfg *ini.File, section string) error { // if it doesn't exist, we're good if cfg.Section(section) == nil { @@ -42,21 +46,60 @@ func validateSection(cfg *ini.File, section string) error { // it should not have any of source_profile, role_arn, mfa_serial, external_id, or credential_source for _, key := range []string{"source_profile", "role_arn", "mfa_serial", "external_id", "credential_source", "credential_process"} { if s.HasKey(key) { - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "section": section, "key": key, }).Errorf("Profile contains key %s, which indicates, it should not be used by clisso", key) - return fmt.Errorf("profile %s contains key %s, which indicates, it should not be used by clisso", section, key) + return fmt.Errorf(errCannotBeUsed, section, key) } } return nil } +// SetCredentialProcess writes the credential_process config to an AWS CLI credentials file in the format required by the SDK +func SetCredentialProcess(filename string, section string) error { + log.WithFields(log.Fields{ + "filename": filename, + "section": section, + }).Debug("Writing credentials to file") + cfg, err := ini.LooseLoad(filename) + if err != nil { + return err + } + err = validateSection(cfg, section) + if err != nil { + if err.Error() == fmt.Sprintf(errCannotBeUsed, section, "credential_process") { + log.Infof(infoProfileAlreadyConfigured, section) + return nil + } + log.WithError(err).Errorf("Profile %s cannot be configured for credential_process", section) + return err + } + if cfg.HasSection(section) { + log.Tracef("Section %s exists and has passed validation, adding credential_process key to it", section) + } + + _, err = cfg.Section(section).NewKey("credential_process", fmt.Sprintf(credentialProcessFormat, section)) + if err != nil { + return err + } + // unset aws_secret_access_key, aws_access_key_id, aws_session_token, aws_expiration + for _, key := range []string{"aws_access_key_id", "aws_secret_access_key", "aws_session_token", expireKey} { + if cfg.Section(section).HasKey(key) { + log.Debugf("Removing key %s from profile %s", key, section) + cfg.Section(section).DeleteKey(key) + } + } + log.Infof("Profile %s is now configured for credential_process", section) + + return cfg.SaveTo(filename) +} + // OutputFile writes credentials to an AWS CLI credentials file // (https://docs.aws.amazon.com/cli/latest/userguide/cli-config-files.html). In addition, this // function removes expired temporary credentials from the credentials file. func OutputFile(c *Credentials, filename string, section string) error { - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "filename": filename, "section": section, }).Debug("Writing credentials to file") @@ -69,7 +112,7 @@ func OutputFile(c *Credentials, filename string, section string) error { return err } if cfg.HasSection(section) { - log.Log.Tracef("Section %s exists and has passed validation, adding aws_access_key_id, aws_secret_access_key, aws_session_token, %s keys to it", section, expireKey) + log.Tracef("Section %s exists and has passed validation, adding aws_access_key_id, aws_secret_access_key, aws_session_token, %s keys to it", section, expireKey) } _, err = cfg.Section(section).NewKey("aws_access_key_id", c.AccessKeyID) @@ -92,27 +135,27 @@ func OutputFile(c *Credentials, filename string, section string) error { // Remove expired credentials. for _, s := range cfg.Sections() { if !s.HasKey(expireKey) { - log.Log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey) + log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey) continue } v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Log.Warnf("Cannot parse date (%v) in profile %s: %s", + log.Warnf("Cannot parse date (%v) in profile %s: %s", s.Key(expireKey), s.Name(), err) continue } if time.Now().UTC().Unix() > v.Unix() { - log.Log.Tracef("Removing expired credentials for profile %s", s.Name()) + log.Tracef("Removing expired credentials for profile %s", s.Name()) for _, key := range []string{"aws_access_key_id", "aws_secret_access_key", "aws_session_token", expireKey} { cfg.Section(s.Name()).DeleteKey(key) } if len(cfg.Section(s.Name()).Keys()) == 0 { - log.Log.Tracef("Removing empty profile %s", s.Name()) + log.Tracef("Removing empty profile %s", s.Name()) cfg.DeleteSection(s.Name()) } continue } - log.Log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339)) + log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339)) } return cfg.SaveTo(filename) @@ -144,8 +187,8 @@ func OutputEnvironment(c *Credentials, windows bool, w io.Writer) { // OutputCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI. // The output can be used to set the credential_process option in the AWS CLI configuration file. func OutputCredentialProcess(c *Credentials, w io.Writer) { - log.Log.Trace("Writing credentials to stdout in credential_process format") - log.Log.Infof("Credentials expire at %s, in %d Minutes", c.Expiration.Format(time.RFC3339), int(c.Expiration.Sub(time.Now().UTC()).Minutes())) + log.Trace("Writing credentials to stdout in credential_process format") + log.Infof("Credentials expire at %s, in %d Minutes", c.Expiration.Format(time.RFC3339), int(c.Expiration.Sub(time.Now().UTC()).Minutes())) fmt.Fprintf( w, `{ "Version": 1, "AccessKeyId": %q, "SecretAccessKey": %q, "SessionToken": %q, "Expiration": %q }`, @@ -160,18 +203,18 @@ func OutputCredentialProcess(c *Credentials, w io.Writer) { // GetValidProfiles returns profiles which have a aws_expiration key but are not yet expired. func GetValidProfiles(filename string) ([]Profile, error) { var profiles []Profile - log.Log.WithField("filename", filename).Trace("Loading AWS credentials file") + log.WithField("filename", filename).Trace("Loading AWS credentials file") cfg, err := ini.LooseLoad(filename) if err != nil { err = fmt.Errorf("%s contains errors: %w", filename, err) - log.Log.WithError(err).Trace("Failed to load AWS credentials file") + log.WithError(err).Trace("Failed to load AWS credentials file") return nil, err } for _, s := range cfg.Sections() { if s.HasKey(expireKey) { v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Log.Warnf("Cannot parse date (%v) in section %s: %s", + log.Warnf("Cannot parse date (%v) in section %s: %s", s.Key(expireKey), s.Name(), err) continue } @@ -190,18 +233,18 @@ func GetValidProfiles(filename string) ([]Profile, error) { // returns a map of profile name to credentials func GetValidCredentials(filename string) (map[string]Credentials, error) { credentials := make(map[string]Credentials) - log.Log.WithField("filename", filename).Trace("Loading credentials file") + log.WithField("filename", filename).Trace("Loading credentials file") cfg, err := ini.LooseLoad(filename) if err != nil { err = fmt.Errorf("%s contains errors: %w", filename, err) - log.Log.WithError(err).Trace("Failed to load credentials file") + log.WithError(err).Trace("Failed to load credentials file") return nil, err } for _, s := range cfg.Sections() { if s.HasKey(expireKey) { v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Log.Warnf("Cannot parse date (%v) in section %s: %s", + log.Warnf("Cannot parse date (%v) in section %s: %s", s.Key(expireKey), s.Name(), err) continue } diff --git a/aws/aws_test.go b/aws/aws_test.go index d84cdee1..685172a7 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -14,9 +14,10 @@ import ( "github.com/allcloud-io/clisso/log" "github.com/go-ini/ini" + "github.com/stretchr/testify/assert" ) -var _ = log.NewLogger("panic", "", false) +var _, hook = log.SetupLogger("panic", "", false, true) func TestWriteToFile(t *testing.T) { id := "expiredkey" @@ -159,6 +160,33 @@ func TestWriteToFile(t *testing.T) { } } +func initConfig(filename string) error { + inifile := ini.Empty() + // add some other config options to ensure we don't overwrite them + inifile.Section("default").Key("region").SetValue("us-west-2") + inifile.Section("default").Key("output").SetValue("json") + + inifile.Section("cred-process").Key("credential_process").SetValue("echo") + + // profile setup for using a source profile + inifile.Section("child-profile").Key("source-profile").SetValue("cred-process") + inifile.Section("child-profile").Key("role_arn").SetValue("arn:aws:iam::123456789012:role/role-name") + + // mock an expired clisso temporary profile + inifile.Section("expiredprofile").Key("aws_access_key_id").SetValue("expiredkey") + inifile.Section("expiredprofile").Key("aws_secret_access_key").SetValue("expired") + inifile.Section("expiredprofile").Key("aws_session_token").SetValue("expiredtoken") + inifile.Section("expiredprofile").Key("aws_expiration").SetValue(time.Now().Add(-time.Duration(1) * time.Hour).UTC().Format(time.RFC3339)) + + // mock a valid clisso temporary profile + inifile.Section("validprofile").Key("aws_access_key_id").SetValue("testkey") + inifile.Section("validprofile").Key("aws_secret_access_key").SetValue("testsecret") + inifile.Section("validprofile").Key("aws_session_token").SetValue("testtoken") + inifile.Section("validprofile").Key("aws_expiration").SetValue(time.Now().Add(time.Duration(1) * time.Hour).UTC().Format(time.RFC3339)) + + return inifile.SaveTo(filename) + +} func TestProtectSections(t *testing.T) { id := "expiredkey" sec := "expiredsecret" @@ -173,17 +201,8 @@ func TestProtectSections(t *testing.T) { } fn := "TestProtectSections.txt" - inifile := ini.Empty() - // add some other config options to ensure we don't overwrite them - inifile.Section("default").Key("region").SetValue("us-west-2") - inifile.Section("default").Key("output").SetValue("json") - - inifile.Section("cred-process").Key("credential_process").SetValue("echo") + err := initConfig(fn) - inifile.Section("child-profile").Key("source-profile").SetValue("cred-process") - inifile.Section("child-profile").Key("role_arn").SetValue("arn:aws:iam::123456789012:role/role-name") - - err := inifile.SaveTo(fn) if err != nil { t.Fatal("Could not write INI file: ", err) } @@ -255,7 +274,7 @@ func TestProtectSections(t *testing.T) { } func TestGetValidProfiles(t *testing.T) { - fn := "test_creds.txt" + fn := "TestGetValidProfiles.txt" id := "testkey" sec := "testsecret" @@ -399,3 +418,50 @@ func TestOutputWindowsEnvironment(t *testing.T) { t.Fatalf("Wrong info written to shell: got %v want %v", got, want) } } + +func TestSetCredentialProcess(t *testing.T) { + assert := assert.New(t) + fn := "TestSetCredentialProcess.txt" + err := initConfig(fn) + assert.Nil(err, "Expected no error, but got: %v", err) + + // nothing to todo, should be skipped + p := "cred-process" + err = SetCredentialProcess(fn, p) + assert.Nil(err, "Expected no error, but got: %v", err) + assert.GreaterOrEqual(len(hook.Entries), 1, "Expected 1 or more log message, but got: %v", hook.Entries) + expected := fmt.Sprintf(infoProfileAlreadyConfigured, p) + assert.Equal(expected, hook.LastEntry().Message, "Expected '%s', but got: %v", expected, hook.LastEntry().Message) + hook.Reset() + + // set credential process on child-profile should fail + err = SetCredentialProcess(fn, "child-profile") + assert.NotNil(err, "Expected an error, but got nil") + assert.GreaterOrEqual(len(hook.Entries), 1, "Expected 1 or more log message, but got: %v", hook.Entries) + expected = fmt.Sprintf(errCannotBeUsed, "child-profile", "role_arn") + assert.Equal(expected, err.Error(), "Expected '%s', but got: %v", expected, err.Error()) + hook.Reset() + + // set credential process on expired and valid profile should work and remove the profile keys + for _, p := range []string{"expiredprofile", "validprofile"} { + err = SetCredentialProcess(fn, p) + assert.Nil(err, "Expected no error, but got: %v", err) + assert.GreaterOrEqual(len(hook.Entries), 1, "Expected 1 or more log message, but got: %v", hook.Entries) + expected = fmt.Sprintf(infoProfileConfigured, p) + assert.Equalf(expected, hook.LastEntry().Message, "Expected '%s', but got: %v", expected, hook.LastEntry().Message) + hook.Reset() + + cfg, err := ini.Load(fn) + assert.Nil(err, "Expected no error, but got: %v", err) + // the section should only have on key left + s := cfg.Section(p) + assert.Equal(1, len(s.Keys()), "Expected 1 key, but got: %v", len(s.Keys())) + // the key should be credential_process + k, err := s.GetKey("credential_process") + assert.Nil(err, "Expected no error, but got: %v", err) + expected = fmt.Sprintf(credentialProcessFormat, p) + assert.Equal(expected, k.String(), "Expected '%s', but got: %v", expected, k.String()) + } + + os.Remove(fn) +} diff --git a/aws/sts.go b/aws/sts.go index 6e0b12e9..f8b8d476 100644 --- a/aws/sts.go +++ b/aws/sts.go @@ -16,7 +16,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" "github.com/icza/gog" - "github.com/sirupsen/logrus" ) const ( @@ -41,13 +40,13 @@ const ( // returns a specific error message to indicate that. In this case we return a custom error to the // caller to allow special handling such as retrying with a lower duration. func AssumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, duration int32) (*Credentials, error) { - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "PrincipalArn": PrincipalArn, "RoleArn": RoleArn, "awsRegion": awsRegion, "duration": duration, }).Debug("Assuming role with SAML assertion") - log.Log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion") + log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion") creds, err := assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion, duration) if err != nil { // Check if API error returned by AWS @@ -81,7 +80,7 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura // If we request credentials for China we need to provide a Chinese region idp := regexp.MustCompile(`^arn:aws-cn:iam::\d+:saml-provider\/\S+$`) if idp.MatchString(PrincipalArn) && !strings.HasPrefix(awsRegion, "cn-") { - log.Log.Trace("Changing region to cn-north-1 as we are assuming a role in China") + log.Trace("Changing region to cn-north-1 as we are assuming a role in China") awsRegion = "cn-north-1" } svc := sts.New(sts.Options{ @@ -89,11 +88,11 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura // see https://github.com/aws/aws-sdk-go-v2/issues/2392 for reasoning Credentials: nil, }) - log.Log.WithField("awsRegion", awsRegion).Trace("Setup STS") + log.WithField("awsRegion", awsRegion).Trace("Setup STS") aResp, err := svc.AssumeRoleWithSAML(ctx, &input) if err != nil { - log.Log.WithError(err).Debug("Error assuming role with SAML assertion") + log.WithError(err).Debug("Error assuming role with SAML assertion") return nil, err } @@ -102,10 +101,10 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura sessionToken := *aResp.Credentials.SessionToken expiration := *aResp.Credentials.Expiration - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "AccessKeyID": keyID, - "SecretAccessKey": gog.If(log.Log.GetLevel() == logrus.TraceLevel, secretKey, ""), - "SessionToken": gog.If(log.Log.GetLevel() == logrus.TraceLevel, sessionToken, ""), + "SecretAccessKey": gog.If(log.GetLevel() == log.TraceLevel, secretKey, ""), + "SessionToken": gog.If(log.GetLevel() == log.TraceLevel, sessionToken, ""), "Expiration": expiration, }).Debug("Got temporary credentials") diff --git a/cmd/apps.go b/cmd/apps.go index fb398b74..0f8372cf 100644 --- a/cmd/apps.go +++ b/cmd/apps.go @@ -64,7 +64,7 @@ var cmdAppsList = &cobra.Command{ Long: "List all configured apps.", Run: func(cmd *cobra.Command, args []string) { apps := viper.GetStringMap("apps") - log.Log.Trace("Listing apps") + log.Trace("Listing apps") if len(apps) == 0 { fmt.Println("No apps configured") @@ -106,18 +106,18 @@ var cmdAppsCreateOneLogin = &cobra.Command{ // Verify app doesn't exist if exists := viper.Get("apps." + name); exists != nil { - log.Log.Fatalf("App '%s' already exists", name) + log.Fatalf("App '%s' already exists", name) } // Verify provider exists if exists := viper.Get("providers." + provider); exists == nil { - log.Log.Fatalf("Provider '%s' doesn't exist", provider) + log.Fatalf("Provider '%s' doesn't exist", provider) } // Verify provider type pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType != "onelogin" { - log.Log.Fatalf( + log.Fatalf( "Invalid provider type '%s' for a OneLogin app. Type must be 'onelogin'.", pType, ) @@ -135,9 +135,9 @@ var cmdAppsCreateOneLogin = &cobra.Command{ if duration != 0 { // Duration specified - validate value if duration < 3600 || duration > 43200 { - log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } - log.Log.Tracef("Setting duration to %d", duration) + log.Tracef("Setting duration to %d", duration) conf["duration"] = strconv.Itoa(duration) } @@ -146,9 +146,9 @@ var cmdAppsCreateOneLogin = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Log.Fatalf("Error writing config: %v", err) + log.Fatalf("Error writing config: %v", err) } - log.Log.Printf("App '%s' saved to config file", name) + log.Printf("App '%s' saved to config file", name) }, } @@ -162,18 +162,18 @@ var cmdAppsCreateOkta = &cobra.Command{ // Verify app doesn't exist if exists := viper.Get("apps." + name); exists != nil { - log.Log.Fatalf("App '%s' already exists", name) + log.Fatalf("App '%s' already exists", name) } // Verify provider exists if exists := viper.Get("providers." + provider); exists == nil { - log.Log.Fatalf("Provider '%s' doesn't exist", provider) + log.Fatalf("Provider '%s' doesn't exist", provider) } // Verify provider type pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType != "okta" { - log.Log.Fatalf( + log.Fatalf( "Invalid provider type '%s' for an Okta app. Type must be 'okta'.", pType, ) @@ -187,9 +187,9 @@ var cmdAppsCreateOkta = &cobra.Command{ if duration != 0 { // Duration specified - validate value if duration < 3600 || duration > 43200 { - log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } - log.Log.Tracef("Setting duration to %d", duration) + log.Tracef("Setting duration to %d", duration) conf["duration"] = strconv.Itoa(duration) } @@ -198,9 +198,9 @@ var cmdAppsCreateOkta = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Log.Fatalf("Error writing config: %v", err) + log.Fatalf("Error writing config: %v", err) } - log.Log.Printf("App '%s' saved to config file", name) + log.Printf("App '%s' saved to config file", name) }, } @@ -214,19 +214,19 @@ var cmdAppsSelect = &cobra.Command{ if app == "" { viper.Set("global.selected-app", "") - log.Log.Println("Unsetting selected app") + log.Println("Unsetting selected app") } else { if exists := viper.Get("apps." + app); exists == nil { - log.Log.Fatalf("App '%s' doesn't exist", app) + log.Fatalf("App '%s' doesn't exist", app) } - log.Log.Printf("Setting selected app to '%s'", app) + log.Printf("Setting selected app to '%s'", app) viper.Set("global.selected-app", app) } // Write config to file err := viper.WriteConfig() if err != nil { - log.Log.Fatalf("Error writing config: %v", err) + log.Fatalf("Error writing config: %v", err) } }, } @@ -240,7 +240,7 @@ var cmdAppsDelete = &cobra.Command{ app := args[0] if exists := viper.Get("apps." + app); exists == nil { - log.Log.Fatalf("App '%s' doesn't exist", app) + log.Fatalf("App '%s' doesn't exist", app) } // Delete app @@ -249,8 +249,8 @@ var cmdAppsDelete = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Log.Fatalf("Error writing config: %v", err) + log.Fatalf("Error writing config: %v", err) } - log.Log.Printf("App '%s' deleted from config file", app) + log.Printf("App '%s' deleted from config file", app) }, } diff --git a/cmd/credential_process.go b/cmd/credential_process.go new file mode 100644 index 00000000..f6813ce8 --- /dev/null +++ b/cmd/credential_process.go @@ -0,0 +1,123 @@ +package cmd + +import ( + "fmt" + + "github.com/allcloud-io/clisso/aws" + "github.com/allcloud-io/clisso/log" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var cmdCredentialProcess = &cobra.Command{ + Use: "cp", + Short: "manage credential process", + Long: `Enabled or disable credential process functionality.`, +} + +var unlockCmd = &cobra.Command{ + Use: "unlock", + Aliases: []string{"enable"}, + Short: "Unlock the credential_process functionality", + Run: func(cmd *cobra.Command, args []string) { + err := enableCredentialProcess() + if err != nil { + log.Fatal("Failed to unlock credential_process:", err) + } + log.Info("Credential_process unlocked successfully") + }, +} + +var lockCmd = &cobra.Command{ + Use: "lock", + Aliases: []string{"disable"}, + Short: "Lock the credential_process functionality", + Run: func(cmd *cobra.Command, args []string) { + err := disableCredentialProcess() + if err != nil { + log.Fatal("Failed to lock credential_process:", err) + } + log.Info("Credential_process locked successfully") + }, +} + +var lockStatusCmd = &cobra.Command{ + Use: "status", + Short: "Check the status of the credential_process functionality", + Run: func(cmd *cobra.Command, args []string) { + credentialProcess := viper.GetString("global.credential-process") + if credentialProcess == "disabled" { + // also change the exit code by logging Fatal + log.Fatal("running as credential_process is disabled") + } else { + log.Infoln("running as credential_process is enabled") + } + }, +} + +var configureCmd = &cobra.Command{ + Use: "configure", + Short: "Configure the credential_process functionality", + Run: func(cmd *cobra.Command, args []string) { + err := configureCredentialProcess() + if err != nil { + log.Fatal("Failed to configure credential_process:", err) + } + log.Info("all apps have been successfully configured as AWS profiles. You can now use them with the AWS CLI/SDK.") + }, +} + +func init() { + cmdCredentialProcess.AddCommand(unlockCmd, lockCmd, lockStatusCmd, configureCmd) + RootCmd.AddCommand(cmdCredentialProcess) + + configureCmd.Flags().StringVarP( + &output, "output", "o", defaultOutput, "where to configure credentials_process profiles", + ) +} + +func enableCredentialProcess() error { + // enable the credential_process functionality by removing the configuration + viper.Set("global.credential-process", "enabled") + err := viper.WriteConfig() + if err != nil { + return err + } + return nil +} + +func disableCredentialProcess() error { + viper.Set("global.credential-process", "disabled") + err := viper.WriteConfig() + if err != nil { + return err + } + + return nil +} + +func checkCredentialProcessActive(printToCredentialProcess bool) { + if printToCredentialProcess { + credentialProcess := viper.GetString("global.credential-process") + if credentialProcess == "disabled" { + log.Fatal("running as credential_process is disabled") + } + } +} + +func configureCredentialProcess() error { + o := preferredOutput(cmdCredentialProcess, "") + // check if output is set to credential_process or environment + if o == "credential_process" || o == "environment" { + return fmt.Errorf("output flag cannot be set to '%s' when configuring credential_process", o) + } + // configure all apps as AWS profiles + apps := viper.GetStringMap("apps") + for app := range apps { + err := aws.SetCredentialProcess(o, app) + if err != nil { + return err + } + } + return nil +} diff --git a/cmd/credential_process_test.go b/cmd/credential_process_test.go new file mode 100644 index 00000000..b3f367c7 --- /dev/null +++ b/cmd/credential_process_test.go @@ -0,0 +1,116 @@ +package cmd + +import ( + "os" + "testing" + + "github.com/allcloud-io/clisso/log" + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +var _, hook = log.SetupLogger("panic", "", false, true) + +func TestEmptyConfig(t *testing.T) { + // set viper config file to a temporary file + viper.SetConfigFile("TestEmptyConfig.yaml") + + checkCredentialProcessActive(true) + // check hook for log.Fatal + if hook.LastEntry() != nil { + t.Errorf("Expected no log.Fatal, but got: %v", hook.LastEntry()) + } + os.Remove("TestEnableCredentialProcess.yaml") +} + +func TestEnableCredentialProcess(t *testing.T) { + // set viper config file to a temporary file + viper.SetConfigFile("TestEnableCredentialProcess.yaml") + // Set up test environment + viper.Set("global.credential-process", "disabled") + + // Call the function to enable credential process + err := enableCredentialProcess() + + // Check if the function returned an error + if err != nil { + t.Errorf("Expected no error, but got: %v", err) + } + + // Check if the credential process is enabled + credentialProcess := viper.GetString("global.credential-process") + // key not there means it's enabled + if credentialProcess != "enabled" { + t.Errorf("Expected credential process to be enabled, but got: %s", credentialProcess) + } + os.Remove("TestEnableCredentialProcess.yaml") +} + +func TestDisableCredentialProcess(t *testing.T) { + // set viper config file to a temporary file + viper.SetConfigFile("TestDisableCredentialProcess.yaml") + // Set up test environment + viper.Set("global.credential-process", "enabled") + + // Call the function to disable credential process + err := disableCredentialProcess() + + // Check if the function returned an error + if err != nil { + t.Errorf("Expected no error, but got: %v", err) + } + + // Check if the credential process is disabled + credentialProcess := viper.GetString("global.credential-process") + if credentialProcess != "disabled" { + t.Errorf("Expected credential process to be disabled, but got: %s", credentialProcess) + } + os.Remove("TestDisableCredentialProcess.yaml") +} + +func TestCheckCredentialProcessActive(t *testing.T) { + assert := assert.New(t) + // Set up test environment + viper.SetConfigFile("TestCheckCredentialProcessActive.yaml") + viper.Set("global.credential-process", "disabled") + + assert.Equal("disabled", viper.GetString("global.credential-process"), "Expected credential process to be disabled, but got: %s", viper.GetString("global.credential-process")) + + // if we're not running as a credential process, the checkCredentialProcessActive function should just continue + checkCredentialProcessActive(false) + // check hook for log.Fatal + assert.Nil(hook.LastEntry(), "Expected no log.Fatal, but got: %v", hook.LastEntry()) + assert.Equal(0, len(hook.Entries), "Expected no log messages, but got: %v", hook.Entries) + + // // if we're running as a credential process, the checkCredentialProcessActive function should log a fatal message + checkCredentialProcessActive(true) + + assert.Equal(1, len(hook.Entries), "Expected 1 log message, but got: %v", hook.Entries) + if len(hook.Entries) == 1 { + assert.Equal(hook.LastEntry().Message, "running as credential_process is disabled") + assert.Equal(hook.LastEntry().Level, logrus.FatalLevel) + } + + os.Remove("TestCheckCredentialProcessActive.yaml") +} + +func TestConfigureCredentialProcess(t *testing.T) { + assert := assert.New(t) + + viper.SetConfigFile("TestConfigureCredentialProcess.yaml") + + // test default values + err := configureCredentialProcess() + assert.Nil(err, "Expected no error, but got: %v", err) + + for _, o := range []string{"credential_process", "environment"} { + viper.Set("global.output", o) + err = configureCredentialProcess() + assert.NotNil(err, "Expected an error, but got nil") + assert.EqualError(err, "output flag cannot be set to '"+o+"' when configuring credential_process") + } + + os.Remove("TestConfigureCredentialProcess.yaml") + +} diff --git a/cmd/get.go b/cmd/get.go index 28353630..4406b7ac 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -74,53 +74,13 @@ func init() { lock, err = lockfile.New(filepath.Join(os.TempDir(), "clisso.lock")) if err != nil { - log.Log.Fatalf("Failed to create lock: %v", err) + log.Fatalf("Failed to create lock: %v", err) } } -func preferredOutput(cmd *cobra.Command, app string) string { - // Order of preference: - // * output flag - // * write-to-file flag (deprecated) - // * app specific config file - // * global config file - // * default to ~/.aws/credentials - out, err := cmd.Flags().GetString("output") - if err != nil { - log.Log.Warnf("Error getting output flag: %v", err) - } - if out != "" && out != defaultOutput { - log.Log.Tracef("output flag sets output to: %s", out) - return out - } - - out, err = cmd.Flags().GetString("write-to-file") - if err != nil { - log.Log.Warnf("Error getting write-to-file flag: %v", err) - } - if out != "" && out != defaultOutput { - log.Log.Tracef("write-to-file flag sets output: %s", out) - return out - } - - out = viper.GetString(fmt.Sprintf("apps.%s.output", app)) - if out != "" { - log.Log.Tracef("App specific config sets output to: %s", out) - return out - } - - out = viper.GetString("global.output") - if out != "" { - log.Log.Tracef("Global config sets output to: %s", out) - return out - } - - return defaultOutput -} - func setOutput(cmd *cobra.Command, app string) { o := preferredOutput(cmd, app) - log.Log.Tracef("Preferred output: %s", o) + log.Tracef("Preferred output: %s", o) writeToFile = "" switch o { case "environment": @@ -145,7 +105,7 @@ func processCredentials(creds *aws.Credentials, app string) error { if cacheCredentials { if err := writeCredentialsToFile(creds, app, cacheToFile); err != nil { - log.Log.Errorf("writing credentials to file: %v", err) + log.Errorf("writing credentials to file: %v", err) } } @@ -159,14 +119,14 @@ func processCredentials(creds *aws.Credentials, app string) error { } func writeCredentialsToFile(creds *aws.Credentials, app, file string) error { - log.Log.Tracef("Writing credentials to '%s'", file) + log.Tracef("Writing credentials to '%s'", file) path, err := homedir.Expand(file) if err != nil { return fmt.Errorf("expanding config file path: %v", err) } credsFileParentDir := filepath.Dir(path) if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { - log.Log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) + log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) // Lets default to strict permissions on the folders we create err = os.MkdirAll(credsFileParentDir, 0700) if err != nil { @@ -177,7 +137,7 @@ func writeCredentialsToFile(creds *aws.Credentials, app, file string) error { if err := aws.OutputFile(creds, path, app); err != nil { return fmt.Errorf("writing credentials to file: %v", err) } - log.Log.Printf("Credentials written successfully to '%s'", path) + log.Printf("Credentials written successfully to '%s'", path) return nil } @@ -213,15 +173,15 @@ func awsRegion(app string) string { func getCachedCredential(app string) (*aws.Credentials, error) { // get the credentials from the cache file - log.Log.Tracef("Looking for cached credentials in '%s'", cacheToFile) + log.Tracef("Looking for cached credentials in '%s'", cacheToFile) credentialFile, err := homedir.Expand(cacheToFile) if err != nil { - log.Log.Fatalf("Failed to expand home: %s", err) + log.Fatalf("Failed to expand home: %s", err) } profiles, err := aws.GetValidCredentials(credentialFile) if err != nil { - log.Log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + log.Fatalf("Failed to retrieve non-expired credentials: %s", err) } if len(profiles) == 0 { @@ -252,7 +212,7 @@ If no app is specified, the selected app (if configured) will be assumed.`, selected := viper.GetString("global.selected-app") if selected == "" { // No default app configured. - log.Log.Fatal("No app specified and no default app configured") + log.Fatal("No app specified and no default app configured") } app = selected } else { @@ -262,15 +222,15 @@ If no app is specified, the selected app (if configured) will be assumed.`, provider := viper.GetString(fmt.Sprintf("apps.%s.provider", app)) if provider == "" { - log.Log.Fatalf("Could not get provider for app '%s'", app) + log.Fatalf("Could not get provider for app '%s'", app) } pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType == "" { - log.Log.Fatalf("Could not get provider type for provider '%s'", provider) + log.Fatalf("Could not get provider type for provider '%s'", provider) } - log.Log.Infof("Getting credentials for app '%s' using provider '%s' (type: %s)", app, provider, pType) + log.Infof("Getting credentials for app '%s' using provider '%s' (type: %s)", app, provider, pType) // allow preferred "arn" to be specified in the config file for each app // if this is not specified the value will be empty ("") @@ -286,11 +246,11 @@ If no app is specified, the selected app (if configured) will be assumed.`, defer unlock() if printToCredentialProcess && cacheCredentials { - log.Log.Trace("Using --cache-credentials and --output-process") + log.Trace("Using --cache-credentials and --output-process") // we need to cache the credentials to a file and return valid credentials instead of constantly hitting the IdPs credential, err := getCachedCredential(app) if err != nil { - log.Log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) + log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) } if credential != nil { aws.OutputCredentialProcess(credential, os.Stdout) @@ -298,29 +258,31 @@ If no app is specified, the selected app (if configured) will be assumed.`, } } + checkCredentialProcessActive(printToCredentialProcess) + interactive := !printToShell && !printToCredentialProcess if pType == "onelogin" { creds, err := onelogin.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { - log.Log.Fatal("Could not get temporary credentials: ", err) + log.Fatal("Could not get temporary credentials: ", err) } // Process credentials err = processCredentials(creds, app) if err != nil { - log.Log.Fatalf("Error processing credentials: %v", err) + log.Fatalf("Error processing credentials: %v", err) } } else if pType == "okta" { creds, err := okta.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { - log.Log.Fatal("Could not get temporary credentials: ", err) + log.Fatal("Could not get temporary credentials: ", err) } // Process credentials err = processCredentials(creds, app) if err != nil { - log.Log.Fatalf("Error processing credentials: %v", err) + log.Fatalf("Error processing credentials: %v", err) } } else { - log.Log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) + log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) } if interactive { printStatus() @@ -336,15 +298,15 @@ func ensureLocked() { if err == nil { return } - log.Log.Tracef("Sleeping, failed to get lock: %v", err) + log.Tracef("Sleeping, failed to get lock: %v", err) time.Sleep(100 * time.Millisecond) } - log.Log.Fatalf("Failed to get lock") + log.Fatalf("Failed to get lock") } func unlock() { if err := lock.Unlock(); err != nil { - log.Log.Fatalf("Failed to unlock: %v", err) + log.Fatalf("Failed to unlock: %v", err) } } diff --git a/cmd/get_test.go b/cmd/get_test.go index cb65b2c9..e26f65fe 100644 --- a/cmd/get_test.go +++ b/cmd/get_test.go @@ -9,11 +9,10 @@ import ( "testing" "github.com/allcloud-io/clisso/log" - "github.com/spf13/cobra" "github.com/spf13/viper" ) -var _ = log.NewLogger("panic", "", false) +var _, _ = log.SetupLogger("panic", "", false, true) var testdata = []struct { app int32 @@ -37,42 +36,3 @@ func TestSessionDuration(t *testing.T) { } } } - -func TestPreferredOutput(t *testing.T) { - testdata := []struct { - outputFlag string - writeToFileFlag string - appConfig string - globalConfig string - result string - }{ - {"environment", "", "", "", "environment"}, - {"credential_process", "", "", "", "credential_process"}, - {"", "", "~/.aws/test", "", "~/.aws/test"}, - {"", "", "", "test", "test"}, - {defaultOutput, "", "", "", defaultOutput}, - {defaultOutput, "", "credential_process", "", "credential_process"}, - {defaultOutput, "", "credential_process", "~/global", "credential_process"}, - {defaultOutput, "", "", "~/global", "~/global"}, - {"~/test", "", "credential_process", "", "~/test"}, - {"~/test", "", "credential_process", "~/global", "~/test"}, - {"~/test", "", "", "~/global", "~/test"}, - } - for _, tc := range testdata { - viper.Set("apps.test.output", tc.appConfig) - viper.Set("global.output", tc.globalConfig) - - cmd := &cobra.Command{} - cmd.Flags().StringVarP( - &output, "output", "o", tc.outputFlag, "fake", - ) - cmd.Flags().StringVarP( - &output, "write-to-file", "f", tc.outputFlag, "fake legacy flag", - ) - - res := preferredOutput(cmd, "test") - if res != tc.result { - t.Fatalf("Invalid output: got %v, want: %v", res, tc.result) - } - } -} diff --git a/cmd/helpers.go b/cmd/helpers.go index 59698dc7..25f92565 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -6,13 +6,57 @@ package cmd import ( + "fmt" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" + "github.com/spf13/viper" ) func mandatoryFlag(cmd *cobra.Command, name string) { err := cmd.MarkFlagRequired(name) if err != nil { - log.Log.Fatalf("Error marking flag %s as required: %v", name, err) + log.Fatalf("Error marking flag %s as required: %v", name, err) + } +} + +func preferredOutput(cmd *cobra.Command, app string) string { + // Order of preference: + // * output flag + // * write-to-file flag (deprecated) + // * app specific config file + // * global config file + // * default to ~/.aws/credentials + out, err := cmd.Flags().GetString("output") + if err != nil { + log.Warnf("Error getting output flag: %v", err) + } + if out != "" && out != defaultOutput { + log.Tracef("output flag sets output to: %s", out) + return out + } + + out, err = cmd.Flags().GetString("write-to-file") + if err != nil { + log.Warnf("Error getting write-to-file flag: %v", err) + } + if out != "" && out != defaultOutput { + log.Tracef("write-to-file flag sets output: %s", out) + return out + } + if app != "" { + out = viper.GetString(fmt.Sprintf("apps.%s.output", app)) + if out != "" { + log.Tracef("App specific config sets output to: %s", out) + return out + } } + + out = viper.GetString("global.output") + if out != "" { + log.Tracef("Global config sets output to: %s", out) + return out + } + + return defaultOutput } diff --git a/cmd/helpers_test.go b/cmd/helpers_test.go new file mode 100644 index 00000000..2a73797e --- /dev/null +++ b/cmd/helpers_test.go @@ -0,0 +1,47 @@ +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func TestPreferredOutput(t *testing.T) { + testdata := []struct { + outputFlag string + writeToFileFlag string + appConfig string + globalConfig string + result string + }{ + {"environment", "", "", "", "environment"}, + {"credential_process", "", "", "", "credential_process"}, + {"", "", "~/.aws/test", "", "~/.aws/test"}, + {"", "", "", "test", "test"}, + {defaultOutput, "", "", "", defaultOutput}, + {defaultOutput, "", "credential_process", "", "credential_process"}, + {defaultOutput, "", "credential_process", "~/global", "credential_process"}, + {defaultOutput, "", "", "~/global", "~/global"}, + {"~/test", "", "credential_process", "", "~/test"}, + {"~/test", "", "credential_process", "~/global", "~/test"}, + {"~/test", "", "", "~/global", "~/test"}, + } + for _, tc := range testdata { + viper.Set("apps.test.output", tc.appConfig) + viper.Set("global.output", tc.globalConfig) + + cmd := &cobra.Command{} + cmd.Flags().StringVarP( + &output, "output", "o", tc.outputFlag, "fake", + ) + cmd.Flags().StringVarP( + &output, "write-to-file", "f", tc.outputFlag, "fake legacy flag", + ) + + res := preferredOutput(cmd, "test") + if res != tc.result { + t.Fatalf("Invalid output: got %v, want: %v", res, tc.result) + } + } +} diff --git a/cmd/providers.go b/cmd/providers.go index a9f0c3a7..ca5000e2 100644 --- a/cmd/providers.go +++ b/cmd/providers.go @@ -77,7 +77,7 @@ var cmdProvidersList = &cobra.Command{ providers := viper.GetStringMap("providers") if len(providers) == 0 { - log.Log.Println("No providers configured") + log.Println("No providers configured") return } @@ -88,7 +88,7 @@ var cmdProvidersList = &cobra.Command{ } sort.Strings(keys) for _, k := range keys { - log.Log.Println(k) + log.Println(k) } }, } @@ -104,16 +104,16 @@ var cmdProvidersPassword = &cobra.Command{ fmt.Printf("Please enter the password for the '%s' provider: ", provider) pass, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { - log.Log.Fatalf("Could not read password") + log.Fatalf("Could not read password") } keyChain := keychain.DefaultKeychain{} err = keyChain.Set(provider, pass) if err != nil { - log.Log.Fatalf("Could not save to keychain: %+v", err) + log.Fatalf("Could not save to keychain: %+v", err) } - log.Log.Printf("Saved password for Provider '%s'", provider) + log.Printf("Saved password for Provider '%s'", provider) }, } @@ -133,13 +133,13 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ // Verify provider doesn't exist if exists := viper.Get("providers." + name); exists != nil { - log.Log.Fatalf("Provider '%s' already exists", name) + log.Fatalf("Provider '%s' already exists", name) } switch region { case "US", "EU": default: - log.Log.Fatal("Region must be either US or EU") + log.Fatal("Region must be either US or EU") } conf := map[string]string{ @@ -153,7 +153,7 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ if providerDuration != 0 { // Duration specified - validate value if providerDuration < 3600 || providerDuration > 43200 { - log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } conf["duration"] = strconv.Itoa(providerDuration) } @@ -162,9 +162,9 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Log.Fatalf("Error writing config: %v", err) + log.Fatalf("Error writing config: %v", err) } - log.Log.Printf("Provider '%s' saved to config file", name) + log.Printf("Provider '%s' saved to config file", name) }, } @@ -178,7 +178,7 @@ var cmdProvidersCreateOkta = &cobra.Command{ // Verify provider doesn't exist if exists := viper.Get("providers." + name); exists != nil { - log.Log.Fatalf("Provider '%s' already exists", name) + log.Fatalf("Provider '%s' already exists", name) } conf := map[string]string{ @@ -189,7 +189,7 @@ var cmdProvidersCreateOkta = &cobra.Command{ if providerDuration != 0 { // Duration specified - validate value if providerDuration < 3600 || providerDuration > 43200 { - log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } conf["duration"] = strconv.Itoa(providerDuration) } @@ -198,8 +198,8 @@ var cmdProvidersCreateOkta = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Log.Fatalf("Error writing config: %v", err) + log.Fatalf("Error writing config: %v", err) } - log.Log.Printf("Provider '%s' saved to config file", name) + log.Printf("Provider '%s' saved to config file", name) }, } diff --git a/cmd/root.go b/cmd/root.go index d22027e6..409f3b41 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -111,7 +111,7 @@ func initConfig(cmd *cobra.Command) error { } else { home, err := homedir.Dir() if err != nil { - log.Log.Fatalf("Error getting home directory: %v", err) + log.Fatalf("Error getting home directory: %v", err) } viper.SetConfigType("yaml") @@ -134,7 +134,7 @@ func initConfig(cmd *cobra.Command) error { panic(fmt.Errorf("can't read config: %v", err)) } bindFlags(cmd, viper.GetViper()) - _ = log.NewLogger(logLevel, logFile, logFile != "") + _, _ = log.SetupLogger(logLevel, logFile, logFile != "", false) return nil } diff --git a/cmd/status.go b/cmd/status.go index 12e60ca8..abde6b32 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -28,7 +28,7 @@ func init() { ) err := viper.BindPFlag("global.output", cmdStatus.Flags().Lookup("read-from-file")) if err != nil { - log.Log.Fatalf("Error binding flag global.output: %v", err) + log.Fatalf("Error binding flag global.output: %v", err) } } @@ -44,16 +44,17 @@ var cmdStatus = &cobra.Command{ func printStatus() { credentialFile, err := homedir.Expand(viper.GetString("global.output")) if err != nil { - log.Log.Fatalf("Failed to expand home: %s", err) + log.Fatalf("Failed to expand home: %s", err) } - log.Log.Trace("Credential file: ", credentialFile) + log.Trace("Credential file: ", credentialFile) if credentialFile == "credential_process" || credentialFile == "environment" { + // TODO: Implement checking the cache file for valid credentials return } profiles, err := aws.GetValidProfiles(credentialFile) if err != nil { - log.Log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + log.Fatalf("Failed to retrieve non-expired credentials: %s", err) } if len(profiles) == 0 { @@ -64,7 +65,7 @@ func printStatus() { table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"App", "Expire At", "Remaining"}) - log.Log.Print("The following apps currently have valid credentials:") + log.Print("The following apps currently have valid credentials:") for _, p := range profiles { table.Append([]string{p.Name, fmt.Sprintf("%d", p.ExpireAtUnix), p.LifetimeLeft.Round(time.Second).String()}) } diff --git a/config/config.go b/config/config.go index fc659600..c34ff724 100644 --- a/config/config.go +++ b/config/config.go @@ -11,7 +11,6 @@ import ( "github.com/allcloud-io/clisso/log" "github.com/icza/gog" - "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -28,14 +27,14 @@ type OneLoginProviderConfig struct { // GetOneLoginProvider returns a OneLoginProviderConfig struct containing the configuration for // provider p. func GetOneLoginProvider(p string) (*OneLoginProviderConfig, error) { - log.Log.WithField("provider", p).Trace("Reading OneLogin provider config") + log.WithField("provider", p).Trace("Reading OneLogin provider config") clientSecret := viper.GetString(fmt.Sprintf("providers.%s.client-secret", p)) clientID := viper.GetString(fmt.Sprintf("providers.%s.client-id", p)) subdomain := viper.GetString(fmt.Sprintf("providers.%s.subdomain", p)) username := viper.GetString(fmt.Sprintf("providers.%s.username", p)) region := viper.GetString(fmt.Sprintf("providers.%s.region", p)) - log.Log.WithFields(logrus.Fields{ - "clientSecret": gog.If(log.Log.GetLevel() == logrus.TraceLevel, clientSecret, ""), + log.WithFields(log.Fields{ + "clientSecret": gog.If(log.GetLevel() == log.TraceLevel, clientSecret, ""), "clientID": clientID, "subdomain": subdomain, "username": username, diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 00000000..cf057ed1 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,57 @@ +package config + +import ( + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func TestOneLoginConfig(t *testing.T) { + assert := assert.New(t) + // use the sample config file + viper.SetConfigFile("../sample_config.yaml") + err := viper.ReadInConfig() + assert.Nil(err) + + onelogin, err := GetOneLoginProvider("sample-onelogin-provider") + assert.Nil(err) + assert.Equal("abcdef-sample-client-id-ghijkl", onelogin.ClientID) + assert.Equal("123456-sample-client-secret-789012", onelogin.ClientSecret) + assert.Equal("sample", onelogin.Subdomain) + assert.Equal("example@example.com", onelogin.Username) + + app, err := GetOneLoginApp("sample-app-1") + assert.Nil(err) + assert.Equal("123456", app.ID) + assert.Equal("sample-onelogin-provider", app.Provider) + + // okta app is missing fields + app, err = GetOneLoginApp("sample-app-2") + assert.Error(err) + assert.Nil(app) + assert.Errorf(err, "app-id config value must bet set") +} + +func TestOktaConfig(t *testing.T) { + assert := assert.New(t) + // use the sample config file + viper.SetConfigFile("../sample_config.yaml") + err := viper.ReadInConfig() + assert.Nil(err) + okta, err := GetOktaProvider("sample-okta-provider") + assert.Nil(err) + assert.Equal("https://xxxxxxxx.oktapreview.com", okta.BaseURL) + assert.Equal("example@example.com", okta.Username) + + app, err := GetOktaApp("sample-app-2") + assert.Nil(err) + assert.Equal("https://xxxxxxxx.oktapreview.com/home/amazon_aws/xxxxxxxxxxxxxxxxxxxx/137", app.URL) + assert.Equal("sample-okta-provider", app.Provider) + + // onelogin app is missing fields + app, err = GetOktaApp("sample-app-1") + assert.Error(err) + assert.Nil(app) + assert.Errorf(err, "url config value must be set") +} diff --git a/go.mod b/go.mod index 2d8518c5..8237106a 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 github.com/zalando/go-keyring v0.2.5 golang.org/x/net v0.27.0 golang.org/x/term v0.22.0 @@ -34,6 +35,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.15 // indirect github.com/beevik/etree v1.2.0 // indirect github.com/danieljoos/wincred v1.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fatih/color v1.14.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect @@ -46,6 +48,7 @@ require ( github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/russellhaering/goxmldsig v1.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/keychain/keychain.go b/keychain/keychain.go index 3bf823d0..fc8176df 100644 --- a/keychain/keychain.go +++ b/keychain/keychain.go @@ -43,15 +43,15 @@ func (DefaultKeychain) Set(provider string, password []byte) (err error) { // and just ask the user for the password instead. Error could be anything from access denied to // password not found. func (DefaultKeychain) Get(provider string) (pw []byte, err error) { - log.Log.WithField("provider", provider).Trace("Reading password from keychain") + log.WithField("provider", provider).Trace("Reading password from keychain") pass, err := get(provider) if err != nil { - log.Log.WithError(err).Trace("Couldn't read password from keychain") + log.WithError(err).Trace("Couldn't read password from keychain") fmt.Printf("Please enter %s password: ", provider) pass, err = term.ReadPassword(int(syscall.Stdin)) if err != nil { err = fmt.Errorf("couldn't read password from terminal: %w", err) - log.Log.WithError(err).Trace("Couldn't read password from terminal") + log.WithError(err).Trace("Couldn't read password from terminal") return nil, err } } diff --git a/keychain/keychain_test.go b/keychain/keychain_test.go index 6e71d3a0..4ac5b556 100644 --- a/keychain/keychain_test.go +++ b/keychain/keychain_test.go @@ -12,7 +12,7 @@ import ( "github.com/allcloud-io/clisso/log" ) -var _ = log.NewLogger("panic", "", false) +var _, _ = log.SetupLogger("panic", "", false, true) func randSeq(n int, letters []rune) []byte { b := make([]rune, n) diff --git a/log/log.go b/log/log.go index 6b4ab123..4a17789c 100644 --- a/log/log.go +++ b/log/log.go @@ -8,28 +8,47 @@ import ( "github.com/mitchellh/go-homedir" "github.com/rifflock/lfshook" "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" ) -var Log *logrus.Logger +var logger = logrus.New() +var Hook *test.Hook +var isTestLogger bool -func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Logger { - if Log != nil { - Log.Tracef("Logger already initialized") - return Log +func SetupLogger(logLevelFlag, logFilePath string, enableLogFile, testLogger bool) (*logrus.Logger, *test.Hook) { + // set isTestLogger to true if testLogger is true or if it was already set to true + isTestLogger = testLogger || isTestLogger + // if testLogger is true, return a test logger + if isTestLogger { + return setupTestLogger() } + return setupProdLogger(logLevelFlag, logFilePath, enableLogFile), nil +} +func setupTestLogger() (*logrus.Logger, *test.Hook) { + if Hook != nil { + // don't overwrite the test hook, we might loose the logs + return logger, Hook + } + logger, Hook = test.NewNullLogger() + logger.ExitFunc = func(int) {} + return logger, Hook +} + +func setupProdLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Logger { // parse log level flag and set log level logLevel, err := logrus.ParseLevel(logLevelFlag) if err != nil { logrus.Fatalf("Error parsing log level: %v", err) } - Log = logrus.New() - Log.SetLevel(logLevel) + logger.SetLevel(logLevel) + // reset Hooks to avoid duplicate entries + logger.Hooks = make(logrus.LevelHooks) if enableLogFile { logFile, err := homedir.Expand(logFilePath) if err != nil { - Log.Fatalf("Error expanding homedir: %v", err) + logger.Fatalf("Error expanding homedir: %v", err) } // set all log levels to write to the log file @@ -42,19 +61,177 @@ func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Log logrus.FatalLevel: logFile, logrus.PanicLevel: logFile, } - Log.Hooks.Add(lfshook.NewHook( + // add the hook to the logger + logger.Hooks.Add(lfshook.NewHook( pathMap, &logrus.JSONFormatter{}, )) - Log.Out = io.Discard + logger.Out = io.Discard } else { if runtime.GOOS == "windows" { // Handle terminal colors on Windows machines. // TODO, check if still required with the switch to logrus - Log.SetOutput(colorable.NewColorableStdout()) + logger.SetOutput(colorable.NewColorableStdout()) } - Log.SetFormatter(&logrus.TextFormatter{PadLevelText: true}) + logger.SetFormatter(&logrus.TextFormatter{PadLevelText: true}) } - Log.Warnf("Log level set to %s", logLevelFlag) - return Log + logger.Warnf("Log level set to %s", logLevelFlag) + return logger +} + +// Fatal is a wrapper for Logrus Fatal +func Fatal(args ...interface{}) { + logger.Fatal(args...) +} + +// Fatalf is a wrapper for Logrus Fatalf +func Fatalf(format string, args ...interface{}) { + logger.Fatalf(format, args...) +} + +// Fatalln is a wrapper for Logrus Fatalln +func Fatalln(args ...interface{}) { + logger.Fatalln(args...) +} + +// Panic is a wrapper for Logrus Panic +func Panic(args ...interface{}) { + logger.Panic(args...) +} + +// Panicf is a wrapper for Logrus Panicf +func Panicf(format string, args ...interface{}) { + logger.Panicf(format, args...) +} + +// Panicln is a wrapper for Logrus Panicln +func Panicln(args ...interface{}) { + logger.Panicln(args...) +} + +// Print is a wrapper for Logrus Print +func Print(args ...interface{}) { + logger.Print(args...) +} + +// Printf is a wrapper for Logrus Printf +func Printf(format string, args ...interface{}) { + logger.Printf(format, args...) +} + +// Println is a wrapper for Logrus Println +func Println(args ...interface{}) { + logger.Println(args...) +} + +// Error is a wrapper for Logrus Error +func Error(args ...interface{}) { + logger.Error(args...) +} + +// Errorf is a wrapper for Logrus Errorf +func Errorf(format string, args ...interface{}) { + logger.Errorf(format, args...) +} + +// Errorln is a wrapper for Logrus Errorln +func Errorln(args ...interface{}) { + logger.Errorln(args...) +} + +// Warn is a wrapper for Logrus Warn +func Warn(args ...interface{}) { + logger.Warn(args...) +} + +// Warnf is a wrapper for Logrus Warnf +func Warnf(format string, args ...interface{}) { + logger.Warnf(format, args...) +} + +// Warnln is a wrapper for Logrus Warnln +func Warnln(args ...interface{}) { + logger.Warnln(args...) +} + +// Info is a wrapper for Logrus Info +func Info(args ...interface{}) { + logger.Info(args...) +} + +// Infof is a wrapper for Logrus Infof +func Infof(format string, args ...interface{}) { + logger.Infof(format, args...) +} + +// Infoln is a wrapper for Logrus Infoln +func Infoln(args ...interface{}) { + logger.Infoln(args...) +} + +// Debug is a wrapper for Logrus Debug +func Debug(args ...interface{}) { + logger.Debug(args...) +} + +// Debugf is a wrapper for Logrus Debugf +func Debugf(format string, args ...interface{}) { + logger.Debugf(format, args...) +} + +// Debugln is a wrapper for Logrus Debugln +func Debugln(args ...interface{}) { + logger.Debugln(args...) +} + +// Trace is a wrapper for Logrus Trace +func Trace(args ...interface{}) { + logger.Trace(args...) +} + +// Tracef is a wrapper for Logrus Tracef +func Tracef(format string, args ...interface{}) { + logger.Tracef(format, args...) } + +// Traceln is a wrapper for Logrus Traceln +func Traceln(args ...interface{}) { + logger.Traceln(args...) +} + +// WithFields is a wrapper for Logrus WithFields +func WithFields(fields Fields) *logrus.Entry { + return logger.WithFields(logrus.Fields(fields)) +} + +// WithField is a wrapper for Logrus WithField +func WithField(key string, value interface{}) *logrus.Entry { + return logger.WithField(key, value) +} + +// WithError is a wrapper for Logrus WithError +func WithError(err error) *logrus.Entry { + return logger.WithError(err) +} + +// wrap log.Fields with a type so we can use it in the WithFields method +type Fields logrus.Fields + +// GetLevel returns the current log level +func GetLevel() Level { + return Level(logger.GetLevel()) +} + +// SetLevel sets the log level +func SetLevel(level Level) { + logger.SetLevel(logrus.Level(level)) +} + +// Level is a wrapper for Logrus Level +type Level logrus.Level + +// TraceLevel is a wrapper for Logrus TraceLevel +const TraceLevel = Level(logrus.TraceLevel) + +// DebugLevel is a wrapper for Logrus DebugLevel +const DebugLevel = Level(logrus.DebugLevel) diff --git a/okta/client.go b/okta/client.go index 1be57fa6..48d7bcb9 100644 --- a/okta/client.go +++ b/okta/client.go @@ -17,7 +17,6 @@ import ( "github.com/PuerkitoBio/goquery" "github.com/allcloud-io/clisso/log" - "github.com/sirupsen/logrus" "golang.org/x/net/publicsuffix" ) @@ -187,7 +186,7 @@ func (c *Client) LaunchApp(p *LaunchAppParams) (*string, error) { // using the client, handles any HTTP-related errors and returns any data as a string. func (c *Client) doRequest(r *http.Request) (string, error) { resp, err := c.Do(r) - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "status": resp.Status, "url": resp.Request.URL, "host": resp.Request.Host, diff --git a/okta/client_test.go b/okta/client_test.go index 6ee9f4f6..18a742b7 100644 --- a/okta/client_test.go +++ b/okta/client_test.go @@ -14,7 +14,7 @@ import ( "github.com/allcloud-io/clisso/log" ) -var _ = log.NewLogger("panic", "", false) +var _, _ = log.SetupLogger("panic", "", false, true) func getTestServer(data string) *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/okta/get.go b/okta/get.go index 6665a063..f571b506 100644 --- a/okta/get.go +++ b/okta/get.go @@ -16,7 +16,6 @@ import ( "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" "github.com/icza/gog" - "github.com/sirupsen/logrus" ) const ( @@ -33,7 +32,7 @@ var ( // Get gets temporary credentials for the given app. func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "app": app, "provider": provider, "pArn": pArn, @@ -80,10 +79,10 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Get session token s.Start() - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "Username": user, // print password only in Trace Log Level - "Password": gog.If(log.Log.GetLevel() == logrus.TraceLevel, string(pass), ""), + "Password": gog.If(log.GetLevel() == log.TraceLevel, string(pass), ""), }).Debug("Calling GetSessionToken") resp, err := c.GetSessionToken(&GetSessionTokenParams{ Username: user, @@ -93,7 +92,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool if err != nil { return nil, fmt.Errorf("getting session token: %v", err) } - log.Log.WithField("Status", resp.Status).Trace("GetSessionToken done") + log.WithField("Status", resp.Status).Trace("GetSessionToken done") var st string @@ -104,7 +103,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool case StatusMFARequired: factor := resp.Embedded.Factors[0] stateToken := resp.StateToken - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "factorID": factor.ID, "factorLink": factor.Links.Verify.Href, "stateToken": stateToken, @@ -174,7 +173,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Handle failed MFA verification (verification rejected or timed out) if vfResp.Status != VerifyFactorStatusSuccess { err = fmt.Errorf("MFA verification failed") - log.Log.WithField("status", vfResp.Status).WithError(err).Warn("MFA verification failed") + log.WithField("status", vfResp.Status).WithError(err).Warn("MFA verification failed") return nil, fmt.Errorf("MFA verification failed") } @@ -185,7 +184,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Launch Okta app with session token s.Start() - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "SessionToken": st, "URL": a.URL, }).Trace("Calling LaunchApp") @@ -206,7 +205,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool if err != nil { if err.Error() == aws.ErrDurationExceeded { - log.Log.Warn(aws.DurationExceededMessage) + log.Warn(aws.DurationExceededMessage) s.Start() creds, err = aws.AssumeSAMLRole(arn.Provider, arn.Role, *samlAssertion, awsRegion, 3600) s.Stop() diff --git a/onelogin/client.go b/onelogin/client.go index 4bb7fb46..c0e5c2dc 100644 --- a/onelogin/client.go +++ b/onelogin/client.go @@ -15,7 +15,6 @@ import ( "time" "github.com/allcloud-io/clisso/log" - "github.com/sirupsen/logrus" ) // Client represents a OneLogin API client. @@ -116,7 +115,7 @@ func (c *Client) doRequest(r *http.Request) (string, error) { resp, err := c.Do(r) if resp != nil { - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "status": resp.Status, "url": resp.Request.URL, "host": resp.Request.Host, diff --git a/onelogin/client_test.go b/onelogin/client_test.go index 8ae6eb53..1491d39f 100644 --- a/onelogin/client_test.go +++ b/onelogin/client_test.go @@ -27,7 +27,7 @@ func getTestServer(data string) *httptest.Server { var c = Client{} -var _ = log.NewLogger("panic", "", false) +var _, _ = log.SetupLogger("panic", "", false, true) func TestNewClient(t *testing.T) { for _, test := range []struct { diff --git a/onelogin/endpoints_test.go b/onelogin/endpoints_test.go index 37483935..cdd3e43a 100644 --- a/onelogin/endpoints_test.go +++ b/onelogin/endpoints_test.go @@ -12,7 +12,7 @@ import ( "github.com/allcloud-io/clisso/log" ) -var _ = log.NewLogger("panic", "", false) +var _, _ = log.SetupLogger("panic", "", false, true) func TestEndpoints_SetBase(t *testing.T) { for _, test := range []struct { diff --git a/onelogin/get.go b/onelogin/get.go index 8246cfeb..e359291d 100644 --- a/onelogin/get.go +++ b/onelogin/get.go @@ -20,7 +20,6 @@ import ( "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" "github.com/icza/gog" - "github.com/sirupsen/logrus" ) const ( @@ -43,7 +42,7 @@ var ( // Get gets temporary credentials for the given app. // TODO Move AWS logic outside this function. func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "app": app, "provider": provider, "pArn": pArn, @@ -72,7 +71,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Get OneLogin access token s.Start() - log.Log.Trace("Generating access token") + log.Trace("Generating access token") token, err := c.GenerateTokens(p.ClientID, p.ClientSecret) s.Stop() if err != nil { @@ -81,7 +80,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool user := p.Username if user == "" { - log.Log.Trace("No username provided") + log.Trace("No username provided") // Get credentials from the user fmt.Print("OneLogin username: ") _, err = fmt.Scanln(&user) @@ -106,10 +105,10 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool Subdomain: p.Subdomain, } - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "UsernameOrEmail": user, // print password only in Trace Log Level - "Password": gog.If(log.Log.GetLevel() == logrus.TraceLevel, string(pass), ""), + "Password": gog.If(log.GetLevel() == log.TraceLevel, string(pass), ""), "AppId": a.ID, "Subdomain": p.Subdomain, }).Debug("Calling GenerateSamlAssertion") @@ -121,14 +120,14 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool return nil, fmt.Errorf("generating SAML assertion: %v", err) } - log.Log.WithField("Message", rSaml.Message).Debug("GenerateSamlAssertion is done") + log.WithField("Message", rSaml.Message).Debug("GenerateSamlAssertion is done") var rData string if rSaml.Message != "Success" { st := rSaml.StateToken devices := rSaml.Devices - log.Log.WithField("Devices", devices).Trace("Devices returned by GenerateSamlAssertion") + log.WithField("Devices", devices).Trace("Devices returned by GenerateSamlAssertion") device, err := getDevice(devices) if err != nil { return nil, fmt.Errorf("error getting devices: %s", err) @@ -148,7 +147,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool OtpToken: "", DoNotNotify: false, } - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "AppId": a.ID, "DeviceId": device.DeviceID, "StateToken": st, @@ -173,7 +172,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool s.Start() for strings.Contains(rMfa.Message, "pending") && timeout > 0 { time.Sleep(time.Duration(MFAInterval) * time.Second) - log.Log.Trace("MFAInterval completed, calling VerifyFactor again") + log.Trace("MFAInterval completed, calling VerifyFactor again") rMfa, err = c.VerifyFactor(token, &pMfa) if err != nil { s.Stop() @@ -216,7 +215,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool } } rData = rMfa.Data - log.Log.Trace("Factor is verified") + log.Trace("Factor is verified") } else { rData = rSaml.Data } @@ -232,7 +231,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool if err != nil { if err.Error() == aws.ErrDurationExceeded { - log.Log.Warn(aws.DurationExceededMessage) + log.Warn(aws.DurationExceededMessage) s.Start() creds, err = aws.AssumeSAMLRole(arn.Provider, arn.Role, rData, awsRegion, 3600) s.Stop() @@ -255,7 +254,7 @@ func getDevice(devices []Device) (device *Device, err error) { } if len(devices) == 1 { - log.Log.Trace("Only one MFA device returned by Onelogin, automatically selecting it.") + log.Trace("Only one MFA device returned by Onelogin, automatically selecting it.") device = &Device{DeviceID: devices[0].DeviceID, DeviceType: devices[0].DeviceType} return } diff --git a/saml/saml.go b/saml/saml.go index f66b606e..356867dd 100644 --- a/saml/saml.go +++ b/saml/saml.go @@ -16,7 +16,6 @@ import ( "github.com/allcloud-io/clisso/log" "github.com/crewjam/saml" - "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -33,14 +32,14 @@ const idpRegex = `^arn:(?:aws|aws-cn):iam::\d+:saml-provider\/\S+$` func Get(data, pArn string) (a ARN, err error) { samlBody, err := decode(data) if err != nil { - log.Log.WithError(err).Error("Error decoding SAML assertion") + log.WithError(err).Error("Error decoding SAML assertion") return } x := new(saml.Response) err = xml.Unmarshal(samlBody, x) if err != nil { - log.Log.WithError(err).Error("Error parsing SAML assertion") + log.WithError(err).Error("Error parsing SAML assertion") return } @@ -68,10 +67,10 @@ func decode(in string) (b []byte, err error) { } func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { - log.Log.WithField("preferredARN", pArn).Trace("Extracting ARNs from SAML AttributeStatements") + log.WithField("preferredARN", pArn).Trace("Extracting ARNs from SAML AttributeStatements") // check for human readable ARN strings in config accounts := viper.GetStringMap("global.accounts") - log.Log.WithFields(accounts).Trace("Accounts loaded from config") + log.WithFields(accounts).Trace("Accounts loaded from config") arns := make([]ARN, 0) for _, stmt := range stmts { @@ -91,7 +90,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { components := strings.Split(strings.TrimSpace(av.Value), ",") if len(components) != 2 { // Wrong number of components - move on - log.Log.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "components": components, "length": len(components), "value": av.Value, @@ -102,7 +101,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // people like to put spaces in there, AWS accepts them, let's remove them on our end too. components[0] = strings.TrimSpace(components[0]) components[1] = strings.TrimSpace(components[1]) - log.Log.WithField("components", components).Trace("ARN components extracted from SAML assertion") + log.WithField("components", components).Trace("ARN components extracted from SAML assertion") arn := ARN{} @@ -111,13 +110,13 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // Otherwise it matches it with what is in the .clisso.yaml file if pArn != "" { if components[0] == pArn { - log.Log.Trace("Preferred ARN matches first component") + log.Trace("Preferred ARN matches first component") arn = ARN{components[0], components[1], ""} } else if components[1] == pArn { - log.Log.Trace("Preferred ARN matches second component") + log.Trace("Preferred ARN matches second component") arn = ARN{components[1], components[0], ""} } else { - log.Log.Trace("Preferred ARN does not match either component") + log.Trace("Preferred ARN does not match either component") continue } } else { @@ -126,20 +125,20 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { idp := regexp.MustCompile(idpRegex) if role.MatchString(components[0]) && idp.MatchString(components[1]) { - log.Log.Trace("First component is role, second component is IdP") + log.Trace("First component is role, second component is IdP") arn = ARN{components[0], components[1], ""} } else if role.MatchString(components[1]) && idp.MatchString(components[0]) { - log.Log.Trace("First component is IdP, second component is role") + log.Trace("First component is IdP, second component is role") arn = ARN{components[1], components[0], ""} } else { - log.Log.Trace("Neither component matches expected pattern") + log.Trace("Neither component matches expected pattern") continue } // Look up the human friendly name, if available if len(accounts) > 0 { ids := role.FindStringSubmatch(arn.Role) - log.Log.WithField("matches", ids).Trace("Role regex matches") + log.WithField("matches", ids).Trace("Role regex matches") // if the regex matches we should have 3 entries from the regex match // 1) the matching string @@ -147,7 +146,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // 3) the match for Name // we want to match the Id to any accounts/roles in our config if len(ids) == 3 && accounts[ids[1]] != "" && accounts[ids[1]] != nil { - log.Log.Trace("Found human friendly name for account") + log.Trace("Found human friendly name for account") arn.Name = fmt.Sprintf("%s - %s", accounts[ids[1]].(string), ids[2]) } } @@ -160,7 +159,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { } } } - log.Log.Trace("No statements in SAML assertion or no ARNs found.") + log.Trace("No statements in SAML assertion or no ARNs found.") // Empty :( return arns } diff --git a/saml/saml_test.go b/saml/saml_test.go index 0debe5e0..a95d89b8 100644 --- a/saml/saml_test.go +++ b/saml/saml_test.go @@ -13,7 +13,7 @@ import ( "github.com/crewjam/saml" ) -var _ = log.NewLogger("panic", "", false) +var _, _ = log.SetupLogger("panic", "", false, true) func TestExtractArns(t *testing.T) { for _, test := range []struct { diff --git a/sample_config.yaml b/sample_config.yaml index 94db7786..4ba7cf81 100644 --- a/sample_config.yaml +++ b/sample_config.yaml @@ -21,8 +21,8 @@ global: "2222222222222": QA providers: sample-onelogin-provider: - client-id: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - client-secret: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + client-id: abcdef-sample-client-id-ghijkl + client-secret: 123456-sample-client-secret-789012 subdomain: sample type: onelogin username: example@example.com diff --git a/spinner/spinner_unix.go b/spinner/spinner_unix.go index 3a2e9e9f..07031ccd 100644 --- a/spinner/spinner_unix.go +++ b/spinner/spinner_unix.go @@ -14,11 +14,10 @@ import ( "github.com/allcloud-io/clisso/log" "github.com/briandowns/spinner" - "github.com/sirupsen/logrus" ) func new(interactive bool) SpinnerWrapper { - if log.Log.GetLevel() >= logrus.DebugLevel || !interactive { + if log.GetLevel() >= log.DebugLevel || !interactive { return &noopSpinner{} } return spinner.New(spinner.CharSets[14], 50*time.Millisecond)