Skip to content

Commit

Permalink
Merge pull request #1174 from 99designs/use-di-for-use-session-cache
Browse files Browse the repository at this point in the history
Use DI for useSessionCache
  • Loading branch information
mtibben committed Mar 5, 2023
2 parents afa09bf + 2a95016 commit ec5e53c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 29 deletions.
2 changes: 1 addition & 1 deletion cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke
return 0, fmt.Errorf("Error loading config: %w", err)
}

credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, input.NoSession)
credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, input.NoSession, true)
if err != nil {
return 0, fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cli/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin
}

ckr := &vault.CredentialKeyring{Keyring: keyring}
credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, input.NoSession)
credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, input.NoSession, true)
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.
ckr := &vault.CredentialKeyring{Keyring: keyring}
if config.HasRole() || config.HasSSOStartURL() || config.HasCredentialProcess() || config.HasWebIdentity() {
// If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession)
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, true)
} else {
credsProvider, err = vault.NewFederationTokenCredentialsProvider(context.TODO(), input.ProfileName, ckr, config)
}
Expand Down
4 changes: 1 addition & 3 deletions cli/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) {
}

func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
vault.UseSessionCache = false

configLoader := vault.NewConfigLoader(input.Config, f, input.ProfileName)
config, err := configLoader.GetProfileConfig(input.ProfileName)
if err != nil {
Expand Down Expand Up @@ -87,7 +85,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin
credsProvider = vault.NewMasterCredentialsProvider(ckr, config.ProfileName)
} else {
// Can't always disable sessions completely, might need to use session for MFA-Protected API Access
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession)
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false)
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
44 changes: 21 additions & 23 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ func init() {
}
}

var UseSessionCache = true

func NewAwsConfig(region, stsRegionalEndpoints string) aws.Config {
return aws.Config{
Region: region,
Expand All @@ -48,7 +46,7 @@ func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string)
return &KeyringProvider{k, credentialsName}
}

func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) {
func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints)

sessionTokenProvider := &SessionTokenProvider{
Expand All @@ -57,7 +55,7 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke
Mfa: NewMfa(config),
}

if UseSessionCache {
if useSessionCache {
return &CachedSessionProvider{
SessionKey: SessionMetadata{
Type: "sts.GetSessionToken",
Expand All @@ -74,7 +72,7 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke
}

// NewAssumeRoleProvider returns a provider that generates credentials using AssumeRole
func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) {
func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints)

p := &AssumeRoleProvider{
Expand All @@ -89,7 +87,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr
Mfa: NewMfa(config),
}

if UseSessionCache && config.MfaSerial != "" {
if useSessionCache && config.MfaSerial != "" {
return &CachedSessionProvider{
SessionKey: SessionMetadata{
Type: "sts.AssumeRole",
Expand All @@ -107,7 +105,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr

// NewAssumeRoleWithWebIdentityProvider returns a provider that generates
// credentials using AssumeRoleWithWebIdentity
func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) {
func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
cfg := NewAwsConfig(config.Region, config.STSRegionalEndpoints)

p := &AssumeRoleWithWebIdentityProvider{
Expand All @@ -119,7 +117,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf
Duration: config.AssumeRoleDuration,
}

if UseSessionCache {
if useSessionCache {
return &CachedSessionProvider{
SessionKey: SessionMetadata{
Type: "sts.AssumeRoleWithWebIdentity",
Expand All @@ -135,7 +133,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf
}

// NewSSORoleCredentialsProvider creates a provider for SSO credentials
func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) {
func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
cfg := NewAwsConfig(config.SSORegion, config.STSRegionalEndpoints)

ssoRoleCredentialsProvider := &SSORoleCredentialsProvider{
Expand All @@ -147,7 +145,7 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig) (aw
UseStdout: config.SSOUseStdout,
}

if UseSessionCache {
if useSessionCache {
ssoRoleCredentialsProvider.OIDCTokenCache = OIDCTokenKeyring{Keyring: k}
return &CachedSessionProvider{
SessionKey: SessionMetadata{
Expand All @@ -166,12 +164,12 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig) (aw

// NewCredentialProcessProvider creates a provider to retrieve credentials from an external
// executable as described in https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes
func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) {
func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
credentialProcessProvider := &CredentialProcessProvider{
CredentialProcess: config.CredentialProcess,
}

if UseSessionCache {
if useSessionCache {
return &CachedSessionProvider{
SessionKey: SessionMetadata{
Type: "credential_process",
Expand Down Expand Up @@ -237,10 +235,10 @@ type tempCredsCreator struct {
chainedMfa string
}

func (t *tempCredsCreator) getSourceCreds(config *ProfileConfig) (sourcecredsProvider aws.CredentialsProvider, err error) {
func (t *tempCredsCreator) getSourceCreds(config *ProfileConfig, useSessionCache bool) (sourcecredsProvider aws.CredentialsProvider, err error) {
if config.HasSourceProfile() {
log.Printf("profile %s: sourcing credentials from profile %s", config.ProfileName, config.SourceProfile.ProfileName)
return t.GetProviderForProfile(config.SourceProfile)
return t.GetProviderForProfile(config.SourceProfile, useSessionCache)
}

hasStoredCredentials, err := t.Keyring.Has(config.ProfileName)
Expand All @@ -256,27 +254,27 @@ func (t *tempCredsCreator) getSourceCreds(config *ProfileConfig) (sourcecredsPro
return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName)
}

func (t *tempCredsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) {
func (t *tempCredsCreator) GetProviderForProfile(config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
if err := config.Validate(); err != nil {
return nil, err
}

if config.HasSSOStartURL() {
log.Printf("profile %s: using SSO role credentials", config.ProfileName)
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config)
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, useSessionCache)
}

if config.HasWebIdentity() {
log.Printf("profile %s: using web identity", config.ProfileName)
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config)
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, useSessionCache)
}

if config.HasCredentialProcess() {
log.Printf("profile %s: using credential process", config.ProfileName)
return NewCredentialProcessProvider(t.Keyring.Keyring, config)
return NewCredentialProcessProvider(t.Keyring.Keyring, config, useSessionCache)
}

sourcecredsProvider, err := t.getSourceCreds(config)
sourcecredsProvider, err := t.getSourceCreds(config, useSessionCache)
if err != nil {
return nil, err
}
Expand All @@ -287,14 +285,14 @@ func (t *tempCredsCreator) GetProviderForProfile(config *ProfileConfig) (aws.Cre
config.MfaSerial = ""
}
log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config))
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config)
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, useSessionCache)
}

canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
if canUseGetSessionToken {
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config)
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, useSessionCache)
}

log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
Expand Down Expand Up @@ -339,10 +337,10 @@ func mfaDetails(mfaChained bool, config *ProfileConfig) string {
}

// NewTempCredentialsProvider creates a credential provider for the given config
func NewTempCredentialsProvider(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool) (aws.CredentialsProvider, error) {
func NewTempCredentialsProvider(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool, useSessionCache bool) (aws.CredentialsProvider, error) {
t := tempCredsCreator{
Keyring: keyring,
DisableSessions: disableSessions,
}
return t.GetProviderForProfile(config)
return t.GetProviderForProfile(config, useSessionCache)
}

0 comments on commit ec5e53c

Please sign in to comment.