Skip to content

Commit

Permalink
Use DI for UseSession
Browse files Browse the repository at this point in the history
  • Loading branch information
mtibben committed Mar 3, 2023
1 parent 382f22e commit f2527e2
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 52 deletions.
4 changes: 1 addition & 3 deletions cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,12 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke
return 0, err
}

vault.UseSession = !input.NoSession

config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
if err != nil {
return 0, fmt.Errorf("Error loading config: %w", err)
}

credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring})
credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, !input.NoSession)
if err != nil {
return 0, fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
4 changes: 1 addition & 3 deletions cli/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,13 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin
return fmt.Errorf("in an existing aws-vault subshell; 'exit' from the subshell or unset AWS_VAULT to force")
}

vault.UseSession = !input.NoSession

config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}

ckr := &vault.CredentialKeyring{Keyring: keyring}
credsProvider, err := vault.NewTempCredentialsProvider(config, ckr)
credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, !input.NoSession)
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
4 changes: 1 addition & 3 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
}

func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
vault.UseSession = !input.NoSession

config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
Expand All @@ -108,7 +106,7 @@ func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.
ckr := &vault.CredentialKeyring{Keyring: keyring}
if config.HasRole() || config.HasSSOStartURL() || config.HasCredentialProcess() || config.HasWebIdentity() {
// If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr)
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, !input.NoSession)
} else {
credsProvider, err = vault.NewFederationTokenCredentialsProvider(context.TODO(), input.ProfileName, ckr, config)
}
Expand Down
3 changes: 1 addition & 2 deletions cli/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) {

func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
// Can't disable sessions completely, might need to use session for MFA-Protected API Access
vault.UseSession = !input.NoSession
vault.UseSessionCache = false

configLoader := vault.NewConfigLoader(input.Config, f, input.ProfileName)
Expand Down Expand Up @@ -88,7 +87,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin
if input.NoSession {
credsProvider = vault.NewMasterCredentialsProvider(ckr, config.ProfileName)
} else {
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr)
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, !input.NoSession)
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
30 changes: 0 additions & 30 deletions vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ const (
roleChainingMaximumDuration = 1 * time.Hour
)

// UseSession will disable the use of GetSessionToken when set to false
var UseSession = true

func init() {
ini.PrettyFormat = false
}
Expand Down Expand Up @@ -676,33 +673,6 @@ func (c *Config) HasCredentialProcess() bool {
return c.CredentialProcess != ""
}

// CanUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason
func (c *Config) CanUseGetSessionToken() (bool, string) {
if !UseSession {
return false, "sessions are disabled"
}

if c.IsChained() {
if !c.ChainedFromProfile.HasMfaSerial() {
return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ChainedFromProfile.ProfileName)
}

if !c.HasMfaSerial() && c.ChainedFromProfile.HasMfaSerial() {
return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ProfileName)
}

if c.ChainedFromProfile.MfaSerial != c.MfaSerial {
return false, fmt.Sprintf("MFA serial doesn't match profile '%s'", c.ChainedFromProfile.ProfileName)
}

if c.ChainedFromProfile.AssumeRoleDuration > roleChainingMaximumDuration {
return false, fmt.Sprintf("duration %s in profile '%s' is greater than the AWS maximum %s for chaining MFA", c.ChainedFromProfile.AssumeRoleDuration, c.ChainedFromProfile.ProfileName, roleChainingMaximumDuration)
}
}

return true, ""
}

func (c *Config) GetSessionTokenDuration() time.Duration {
if c.IsChained() {
return c.ChainedGetSessionTokenDuration
Expand Down
53 changes: 42 additions & 11 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ func NewCredentialProcessProvider(k keyring.Keyring, config *Config) (aws.Creden
}

type tempCredsCreator struct {
keyring *CredentialKeyring
// UseSession will disable the use of GetSessionToken when set to false
UseSession bool
Keyring *CredentialKeyring

chainedMfa string
}

Expand All @@ -197,14 +200,14 @@ func (t *tempCredsCreator) getSourceCreds(config *Config) (sourcecredsProvider a
return t.GetProviderForProfile(config.SourceProfile)
}

hasStoredCredentials, err := t.keyring.Has(config.ProfileName)
hasStoredCredentials, err := t.Keyring.Has(config.ProfileName)
if err != nil {
return nil, err
}

if hasStoredCredentials {
log.Printf("profile %s: using stored credentials", config.ProfileName)
return NewMasterCredentialsProvider(t.keyring, config.ProfileName), nil
return NewMasterCredentialsProvider(t.Keyring, config.ProfileName), nil
}

return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName)
Expand All @@ -217,17 +220,17 @@ func (t *tempCredsCreator) GetProviderForProfile(config *Config) (aws.Credential

if config.HasSSOStartURL() {
log.Printf("profile %s: using SSO role credentials", config.ProfileName)
return NewSSORoleCredentialsProvider(t.keyring.Keyring, config)
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config)
}

if config.HasWebIdentity() {
log.Printf("profile %s: using web identity", config.ProfileName)
return NewAssumeRoleWithWebIdentityProvider(t.keyring.Keyring, config)
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config)
}

if config.HasCredentialProcess() {
log.Printf("profile %s: using credential process", config.ProfileName)
return NewCredentialProcessProvider(t.keyring.Keyring, config)
return NewCredentialProcessProvider(t.Keyring.Keyring, config)
}

sourcecredsProvider, err := t.getSourceCreds(config)
Expand All @@ -241,20 +244,47 @@ func (t *tempCredsCreator) GetProviderForProfile(config *Config) (aws.Credential
config.MfaSerial = ""
}
log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config))
return NewAssumeRoleProvider(sourcecredsProvider, t.keyring.Keyring, config)
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config)
}

canUseGetSessionToken, reason := config.CanUseGetSessionToken()
canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
if canUseGetSessionToken {
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
return NewSessionTokenProvider(sourcecredsProvider, t.keyring.Keyring, config)
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config)
}

log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
return sourcecredsProvider, nil
}

// canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason
func (t *tempCredsCreator) canUseGetSessionToken(c *Config) (bool, string) {
if !t.UseSession {
return false, "sessions are disabled"
}

if c.IsChained() {
if !c.ChainedFromProfile.HasMfaSerial() {
return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ChainedFromProfile.ProfileName)
}

if !c.HasMfaSerial() && c.ChainedFromProfile.HasMfaSerial() {
return false, fmt.Sprintf("profile '%s' has no MFA serial defined", c.ProfileName)
}

if c.ChainedFromProfile.MfaSerial != c.MfaSerial {
return false, fmt.Sprintf("MFA serial doesn't match profile '%s'", c.ChainedFromProfile.ProfileName)
}

if c.ChainedFromProfile.AssumeRoleDuration > roleChainingMaximumDuration {
return false, fmt.Sprintf("duration %s in profile '%s' is greater than the AWS maximum %s for chaining MFA", c.ChainedFromProfile.AssumeRoleDuration, c.ChainedFromProfile.ProfileName, roleChainingMaximumDuration)
}
}

return true, ""
}

func mfaDetails(mfaChained bool, config *Config) string {
if mfaChained {
return "(chained MFA)"
Expand All @@ -266,9 +296,10 @@ func mfaDetails(mfaChained bool, config *Config) string {
}

// NewTempCredentialsProvider creates a credential provider for the given config
func NewTempCredentialsProvider(config *Config, keyring *CredentialKeyring) (aws.CredentialsProvider, error) {
func NewTempCredentialsProvider(config *Config, keyring *CredentialKeyring, useSession bool) (aws.CredentialsProvider, error) {
t := tempCredsCreator{
keyring: keyring,
Keyring: keyring,
UseSession: useSession,
}
return t.GetProviderForProfile(config)
}
Expand Down

0 comments on commit f2527e2

Please sign in to comment.