Skip to content

Commit

Permalink
Prioritise source_profile over sso config
Browse files Browse the repository at this point in the history
  • Loading branch information
mtibben committed Mar 20, 2023
1 parent 99d5e91 commit 9dc5bca
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 29 deletions.
71 changes: 42 additions & 29 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func FormatKeyForDisplay(k string) string {
return fmt.Sprintf("****************%s", k[len(k)-4:])
}

func isMasterCredentialsProvider(credsProvider aws.CredentialsProvider) bool {
_, ok := credsProvider.(*KeyringProvider)
return ok
}

// NewMasterCredentialsProvider creates a provider for the master credentials
func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) *KeyringProvider {
return &KeyringProvider{k, credentialsName}
Expand Down Expand Up @@ -243,52 +248,60 @@ func (t *TempCredentialsCreator) getSourceCreds(config *ProfileConfig, hasStored
return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName)
}

func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) {
hasStoredCredentials, err := t.Keyring.Has(config.ProfileName)
func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, hasStoredCredentials bool) (sourcecredsProvider aws.CredentialsProvider, err error) {
sourcecredsProvider, err = t.getSourceCreds(config, hasStoredCredentials)
if err != nil {
return nil, err
}

if !hasStoredCredentials {
if config.HasSSOStartURL() {
log.Printf("profile %s: using SSO role credentials", config.ProfileName)
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache)
}

if config.HasWebIdentity() {
log.Printf("profile %s: using web identity", config.ProfileName)
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache)
if config.HasRole() {
isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa
if isMfaChained {
config.MfaSerial = ""
}
log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config))
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
}

if config.HasCredentialProcess() {
log.Printf("profile %s: using credential process", config.ProfileName)
return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache)
if isMasterCredentialsProvider(sourcecredsProvider) {
canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
if canUseGetSessionToken {
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
}
log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
}

sourcecredsProvider, err := t.getSourceCreds(config, hasStoredCredentials)
return sourcecredsProvider, nil
}

func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) {
hasStoredCredentials, err := t.Keyring.Has(config.ProfileName)
if err != nil {
return nil, err
}

if config.HasRole() {
isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa
if isMfaChained {
config.MfaSerial = ""
}
log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config))
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
if hasStoredCredentials || config.HasSourceProfile() {
return t.getSourceCredWithSession(config, hasStoredCredentials)
}

canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
if canUseGetSessionToken {
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
if config.HasSSOStartURL() {
log.Printf("profile %s: using SSO role credentials", config.ProfileName)
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache)
}

log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
return sourcecredsProvider, nil
if config.HasWebIdentity() {
log.Printf("profile %s: using web identity", config.ProfileName)
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache)
}

if config.HasCredentialProcess() {
log.Printf("profile %s: using credential process", config.ProfileName)
return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache)
}

return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName)
}

// canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason
Expand Down
60 changes: 60 additions & 0 deletions vault/vault_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package vault_test

import (
"os"
"testing"

"github.com/99designs/aws-vault/v7/vault"
"github.com/99designs/keyring"
)

func TestIssue1195(t *testing.T) {
f := newConfigFile(t, []byte(`
[profile test]
source_profile=dev
region=ap-northeast-2
[profile dev]
sso_session=common
sso_account_id=2160xxxx
sso_role_name=AdministratorAccess
region=ap-northeast-2
output=json
[default]
sso_session=common
sso_account_id=3701xxxx
sso_role_name=AdministratorAccess
region=ap-northeast-2
output=json
[sso-session common]
sso_start_url=https://xxxx.awsapps.com/start
sso_region=ap-northeast-2
sso_registration_scopes=sso:account:access
`))
defer os.Remove(f)
configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "test"}
config, err := configLoader.GetProfileConfig("test")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}

ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})}
p, err := vault.NewTempCredentialsProvider(config, ckr, true, true)
if err != nil {
t.Fatal(err)
}

ssoProvider, ok := p.(*vault.SSORoleCredentialsProvider)
if !ok {
t.Fatalf("Expected SSORoleCredentialsProvider, got %T", p)
}
if ssoProvider.AccountID != "2160xxxx" {
t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID)
}
}

0 comments on commit 9dc5bca

Please sign in to comment.