From 9dc5bca94f1359aec52cdd6e2e0df5e701db182e Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Mon, 20 Mar 2023 17:08:55 +1100 Subject: [PATCH] Prioritise source_profile over sso config --- vault/vault.go | 71 +++++++++++++++++++++++++++------------------ vault/vault_test.go | 60 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 29 deletions(-) create mode 100644 vault/vault_test.go diff --git a/vault/vault.go b/vault/vault.go index 9f3d4a6e7..9fd27fe2a 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -41,6 +41,11 @@ func FormatKeyForDisplay(k string) string { return fmt.Sprintf("****************%s", k[len(k)-4:]) } +func isMasterCredentialsProvider(credsProvider aws.CredentialsProvider) bool { + _, ok := credsProvider.(*KeyringProvider) + return ok +} + // NewMasterCredentialsProvider creates a provider for the master credentials func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) *KeyringProvider { return &KeyringProvider{k, credentialsName} @@ -243,52 +248,60 @@ func (t *TempCredentialsCreator) getSourceCreds(config *ProfileConfig, hasStored return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) } -func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { - hasStoredCredentials, err := t.Keyring.Has(config.ProfileName) +func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, hasStoredCredentials bool) (sourcecredsProvider aws.CredentialsProvider, err error) { + sourcecredsProvider, err = t.getSourceCreds(config, hasStoredCredentials) if err != nil { return nil, err } - if !hasStoredCredentials { - if config.HasSSOStartURL() { - log.Printf("profile %s: using SSO role credentials", config.ProfileName) - return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache) - } - - if config.HasWebIdentity() { - log.Printf("profile %s: using web identity", config.ProfileName) - return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache) + if config.HasRole() { + isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa + if isMfaChained { + config.MfaSerial = "" } + log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config)) + return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + } - if config.HasCredentialProcess() { - log.Printf("profile %s: using credential process", config.ProfileName) - return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache) + if isMasterCredentialsProvider(sourcecredsProvider) { + canUseGetSessionToken, reason := t.canUseGetSessionToken(config) + if canUseGetSessionToken { + t.chainedMfa = config.MfaSerial + log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) + return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) } + log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) } - sourcecredsProvider, err := t.getSourceCreds(config, hasStoredCredentials) + return sourcecredsProvider, nil +} + +func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { + hasStoredCredentials, err := t.Keyring.Has(config.ProfileName) if err != nil { return nil, err } - if config.HasRole() { - isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa - if isMfaChained { - config.MfaSerial = "" - } - log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config)) - return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + if hasStoredCredentials || config.HasSourceProfile() { + return t.getSourceCredWithSession(config, hasStoredCredentials) } - canUseGetSessionToken, reason := t.canUseGetSessionToken(config) - if canUseGetSessionToken { - t.chainedMfa = config.MfaSerial - log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + if config.HasSSOStartURL() { + log.Printf("profile %s: using SSO role credentials", config.ProfileName) + return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache) } - log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) - return sourcecredsProvider, nil + if config.HasWebIdentity() { + log.Printf("profile %s: using web identity", config.ProfileName) + return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache) + } + + if config.HasCredentialProcess() { + log.Printf("profile %s: using credential process", config.ProfileName) + return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache) + } + + return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) } // canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason diff --git a/vault/vault_test.go b/vault/vault_test.go new file mode 100644 index 000000000..e4c7466c3 --- /dev/null +++ b/vault/vault_test.go @@ -0,0 +1,60 @@ +package vault_test + +import ( + "os" + "testing" + + "github.com/99designs/aws-vault/v7/vault" + "github.com/99designs/keyring" +) + +func TestIssue1195(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile test] +source_profile=dev +region=ap-northeast-2 + +[profile dev] +sso_session=common +sso_account_id=2160xxxx +sso_role_name=AdministratorAccess +region=ap-northeast-2 +output=json + +[default] +sso_session=common +sso_account_id=3701xxxx +sso_role_name=AdministratorAccess +region=ap-northeast-2 +output=json + +[sso-session common] +sso_start_url=https://xxxx.awsapps.com/start +sso_region=ap-northeast-2 +sso_registration_scopes=sso:account:access +`)) + defer os.Remove(f) + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "test"} + config, err := configLoader.GetProfileConfig("test") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + p, err := vault.NewTempCredentialsProvider(config, ckr, true, true) + if err != nil { + t.Fatal(err) + } + + ssoProvider, ok := p.(*vault.SSORoleCredentialsProvider) + if !ok { + t.Fatalf("Expected SSORoleCredentialsProvider, got %T", p) + } + if ssoProvider.AccountID != "2160xxxx" { + t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID) + } +}