From f2527e2f7a4b09a7f8f0db1b22ec4a3f8c893acc Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Sat, 4 Mar 2023 10:15:55 +1100 Subject: [PATCH] Use DI for UseSession --- cli/exec.go | 4 +--- cli/export.go | 4 +--- cli/login.go | 4 +--- cli/rotate.go | 3 +-- vault/config.go | 30 ---------------------------- vault/vault.go | 53 +++++++++++++++++++++++++++++++++++++++---------- 6 files changed, 46 insertions(+), 52 deletions(-) diff --git a/cli/exec.go b/cli/exec.go index 18ac423fe..3a15b3f76 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -167,14 +167,12 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke return 0, err } - vault.UseSession = !input.NoSession - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName) if err != nil { return 0, fmt.Errorf("Error loading config: %w", err) } - credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}) + credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, !input.NoSession) if err != nil { return 0, fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/cli/export.go b/cli/export.go index 317913551..c09409c76 100644 --- a/cli/export.go +++ b/cli/export.go @@ -90,15 +90,13 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin return fmt.Errorf("in an existing aws-vault subshell; 'exit' from the subshell or unset AWS_VAULT to force") } - vault.UseSession = !input.NoSession - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName) if err != nil { return fmt.Errorf("Error loading config: %w", err) } ckr := &vault.CredentialKeyring{Keyring: keyring} - credsProvider, err := vault.NewTempCredentialsProvider(config, ckr) + credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, !input.NoSession) if err != nil { return fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/cli/login.go b/cli/login.go index df9a9eac7..c2bc16bd9 100644 --- a/cli/login.go +++ b/cli/login.go @@ -81,8 +81,6 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { } func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error { - vault.UseSession = !input.NoSession - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName) if err != nil { return fmt.Errorf("Error loading config: %w", err) @@ -108,7 +106,7 @@ func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring. ckr := &vault.CredentialKeyring{Keyring: keyring} if config.HasRole() || config.HasSSOStartURL() || config.HasCredentialProcess() || config.HasWebIdentity() { // If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials - credsProvider, err = vault.NewTempCredentialsProvider(config, ckr) + credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, !input.NoSession) } else { credsProvider, err = vault.NewFederationTokenCredentialsProvider(context.TODO(), input.ProfileName, ckr, config) } diff --git a/cli/rotate.go b/cli/rotate.go index 0a2b14578..23752e908 100644 --- a/cli/rotate.go +++ b/cli/rotate.go @@ -52,7 +52,6 @@ func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) { func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error { // Can't disable sessions completely, might need to use session for MFA-Protected API Access - vault.UseSession = !input.NoSession vault.UseSessionCache = false configLoader := vault.NewConfigLoader(input.Config, f, input.ProfileName) @@ -88,7 +87,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin if input.NoSession { credsProvider = vault.NewMasterCredentialsProvider(ckr, config.ProfileName) } else { - credsProvider, err = vault.NewTempCredentialsProvider(config, ckr) + credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, !input.NoSession) if err != nil { return fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/vault/config.go b/vault/config.go index 035c6dd50..37b54064e 100644 --- a/vault/config.go +++ b/vault/config.go @@ -23,9 +23,6 @@ const ( roleChainingMaximumDuration = 1 * time.Hour ) -// UseSession will disable the use of GetSessionToken when set to false -var UseSession = true - func init() { ini.PrettyFormat = false } @@ -676,33 +673,6 @@ func (c *Config) HasCredentialProcess() bool { return c.CredentialProcess != "" } -// CanUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason -func (c *Config) CanUseGetSessionToken() (bool, string) { - if !UseSession { - return false, "sessions are disabled" - } - - if c.IsChained() { - if !c.ChainedFromProfile.HasMfaSerial() { - return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ChainedFromProfile.ProfileName) - } - - if !c.HasMfaSerial() && c.ChainedFromProfile.HasMfaSerial() { - return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ProfileName) - } - - if c.ChainedFromProfile.MfaSerial != c.MfaSerial { - return false, fmt.Sprintf("MFA serial doesn't match profile '%s'", c.ChainedFromProfile.ProfileName) - } - - if c.ChainedFromProfile.AssumeRoleDuration > roleChainingMaximumDuration { - return false, fmt.Sprintf("duration %s in profile '%s' is greater than the AWS maximum %s for chaining MFA", c.ChainedFromProfile.AssumeRoleDuration, c.ChainedFromProfile.ProfileName, roleChainingMaximumDuration) - } - } - - return true, "" -} - func (c *Config) GetSessionTokenDuration() time.Duration { if c.IsChained() { return c.ChainedGetSessionTokenDuration diff --git a/vault/vault.go b/vault/vault.go index 7fbea53d1..17ddad135 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -187,7 +187,10 @@ func NewCredentialProcessProvider(k keyring.Keyring, config *Config) (aws.Creden } type tempCredsCreator struct { - keyring *CredentialKeyring + // UseSession will disable the use of GetSessionToken when set to false + UseSession bool + Keyring *CredentialKeyring + chainedMfa string } @@ -197,14 +200,14 @@ func (t *tempCredsCreator) getSourceCreds(config *Config) (sourcecredsProvider a return t.GetProviderForProfile(config.SourceProfile) } - hasStoredCredentials, err := t.keyring.Has(config.ProfileName) + hasStoredCredentials, err := t.Keyring.Has(config.ProfileName) if err != nil { return nil, err } if hasStoredCredentials { log.Printf("profile %s: using stored credentials", config.ProfileName) - return NewMasterCredentialsProvider(t.keyring, config.ProfileName), nil + return NewMasterCredentialsProvider(t.Keyring, config.ProfileName), nil } return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) @@ -217,17 +220,17 @@ func (t *tempCredsCreator) GetProviderForProfile(config *Config) (aws.Credential if config.HasSSOStartURL() { log.Printf("profile %s: using SSO role credentials", config.ProfileName) - return NewSSORoleCredentialsProvider(t.keyring.Keyring, config) + return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config) } if config.HasWebIdentity() { log.Printf("profile %s: using web identity", config.ProfileName) - return NewAssumeRoleWithWebIdentityProvider(t.keyring.Keyring, config) + return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config) } if config.HasCredentialProcess() { log.Printf("profile %s: using credential process", config.ProfileName) - return NewCredentialProcessProvider(t.keyring.Keyring, config) + return NewCredentialProcessProvider(t.Keyring.Keyring, config) } sourcecredsProvider, err := t.getSourceCreds(config) @@ -241,20 +244,47 @@ func (t *tempCredsCreator) GetProviderForProfile(config *Config) (aws.Credential config.MfaSerial = "" } log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config)) - return NewAssumeRoleProvider(sourcecredsProvider, t.keyring.Keyring, config) + return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config) } - canUseGetSessionToken, reason := config.CanUseGetSessionToken() + canUseGetSessionToken, reason := t.canUseGetSessionToken(config) if canUseGetSessionToken { t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - return NewSessionTokenProvider(sourcecredsProvider, t.keyring.Keyring, config) + return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config) } log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) return sourcecredsProvider, nil } +// canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason +func (t *tempCredsCreator) canUseGetSessionToken(c *Config) (bool, string) { + if !t.UseSession { + return false, "sessions are disabled" + } + + if c.IsChained() { + if !c.ChainedFromProfile.HasMfaSerial() { + return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ChainedFromProfile.ProfileName) + } + + if !c.HasMfaSerial() && c.ChainedFromProfile.HasMfaSerial() { + return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ProfileName) + } + + if c.ChainedFromProfile.MfaSerial != c.MfaSerial { + return false, fmt.Sprintf("MFA serial doesn't match profile '%s'", c.ChainedFromProfile.ProfileName) + } + + if c.ChainedFromProfile.AssumeRoleDuration > roleChainingMaximumDuration { + return false, fmt.Sprintf("duration %s in profile '%s' is greater than the AWS maximum %s for chaining MFA", c.ChainedFromProfile.AssumeRoleDuration, c.ChainedFromProfile.ProfileName, roleChainingMaximumDuration) + } + } + + return true, "" +} + func mfaDetails(mfaChained bool, config *Config) string { if mfaChained { return "(chained MFA)" @@ -266,9 +296,10 @@ func mfaDetails(mfaChained bool, config *Config) string { } // NewTempCredentialsProvider creates a credential provider for the given config -func NewTempCredentialsProvider(config *Config, keyring *CredentialKeyring) (aws.CredentialsProvider, error) { +func NewTempCredentialsProvider(config *Config, keyring *CredentialKeyring, useSession bool) (aws.CredentialsProvider, error) { t := tempCredsCreator{ - keyring: keyring, + Keyring: keyring, + UseSession: useSession, } return t.GetProviderForProfile(config) }