From 56986dee624f64b6b83c9ac12339fe053f5f407f Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Tue, 1 Sep 2020 13:48:58 +1000 Subject: [PATCH] Report number of cleared sessions more accurately --- cli/clear.go | 8 ++++++-- vault/sessionkeyring.go | 36 +++++++++++------------------------- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/cli/clear.go b/cli/clear.go index c7cd3ba65..64242752d 100644 --- a/cli/clear.go +++ b/cli/clear.go @@ -40,9 +40,13 @@ func ConfigureClearCommand(app *kingpin.Application, a *AwsVault) { func ClearCommand(input ClearCommandInput, awsConfigFile *vault.ConfigFile, keyring keyring.Keyring) error { sessions := &vault.SessionKeyring{Keyring: keyring} oidcTokens := &vault.OIDCTokenKeyring{Keyring: keyring} - var numSessionsRemoved, numTokensRemoved int + var oldSessionsRemoved, numSessionsRemoved, numTokensRemoved int var err error if input.ProfileName == "" { + oldSessionsRemoved, err = sessions.RemoveOldSessions() + if err != nil { + return err + } numSessionsRemoved, err = sessions.RemoveAll() if err != nil { return err @@ -67,7 +71,7 @@ func ClearCommand(input ClearCommandInput, awsConfigFile *vault.ConfigFile, keyr } } } - fmt.Printf("Cleared %d sessions.\n", numSessionsRemoved+numTokensRemoved) + fmt.Printf("Cleared %d sessions.\n", oldSessionsRemoved+numSessionsRemoved+numTokensRemoved) return nil } diff --git a/vault/sessionkeyring.go b/vault/sessionkeyring.go index 089758f41..255b2cdde 100644 --- a/vault/sessionkeyring.go +++ b/vault/sessionkeyring.go @@ -21,7 +21,7 @@ var oldSessionKeyPatterns = []*regexp.Regexp{ regexp.MustCompile(`^session:(?P[^ ]+):(?P[^ ]*):(?P[^:]+)$`), regexp.MustCompile(`^(.+?) session \((\d+)\)$`), } -var base64Encoding = base64.URLEncoding.WithPadding(base64.NoPadding) +var base64URLEncodingNoPadding = base64.URLEncoding.WithPadding(base64.NoPadding) func IsOldSessionKey(s string) bool { for _, pattern := range oldSessionKeyPatterns { @@ -52,8 +52,8 @@ func (k *SessionMetadata) String() string { return fmt.Sprintf( "%s,%s,%s,%d", k.Type, - base64Encoding.EncodeToString([]byte(k.ProfileName)), - base64Encoding.EncodeToString([]byte(k.MfaSerial)), + base64URLEncodingNoPadding.EncodeToString([]byte(k.ProfileName)), + base64URLEncodingNoPadding.EncodeToString([]byte(k.MfaSerial)), k.Expiration.Unix(), ) } @@ -62,8 +62,8 @@ func (k *SessionMetadata) StringForMatching() string { return fmt.Sprintf( "%s,%s,%s,", k.Type, - base64Encoding.EncodeToString([]byte(k.ProfileName)), - base64Encoding.EncodeToString([]byte(k.MfaSerial)), + base64URLEncodingNoPadding.EncodeToString([]byte(k.ProfileName)), + base64URLEncodingNoPadding.EncodeToString([]byte(k.MfaSerial)), ) } @@ -73,11 +73,11 @@ func NewSessionKeyFromString(s string) (SessionMetadata, error) { return SessionMetadata{}, fmt.Errorf("failed to parse session name: %s", s) } - profileName, err := base64Encoding.DecodeString(matches[2]) + profileName, err := base64URLEncodingNoPadding.DecodeString(matches[2]) if err != nil { return SessionMetadata{}, err } - mfaSerial, err := base64Encoding.DecodeString(matches[3]) + mfaSerial, err := base64URLEncodingNoPadding.DecodeString(matches[3]) if err != nil { return SessionMetadata{}, err } @@ -95,8 +95,7 @@ func NewSessionKeyFromString(s string) (SessionMetadata, error) { } type SessionKeyring struct { - Keyring keyring.Keyring - isGarbageCollected bool + Keyring keyring.Keyring } var ErrNotFound = keyring.ErrKeyNotFound @@ -127,7 +126,7 @@ func (sk *SessionKeyring) Has(key SessionMetadata) (bool, error) { } func (sk *SessionKeyring) Get(key SessionMetadata) (val *sts.Credentials, err error) { - sk.garbageCollect() + sk.RemoveOldSessions() keyName, err := sk.lookupKeyName(key) if err != nil && err != ErrNotFound { @@ -145,7 +144,7 @@ func (sk *SessionKeyring) Get(key SessionMetadata) (val *sts.Credentials, err er } func (sk *SessionKeyring) Set(key SessionMetadata, val *sts.Credentials) error { - sk.garbageCollectOnce() + sk.RemoveOldSessions() key.Expiration = *val.Expiration @@ -176,8 +175,6 @@ func (sk *SessionKeyring) Set(key SessionMetadata, val *sts.Credentials) error { } func (sk *SessionKeyring) Remove(key SessionMetadata) error { - sk.garbageCollectOnce() - keyName, err := sk.lookupKeyName(key) if err != nil && err != ErrNotFound { return err @@ -187,8 +184,6 @@ func (sk *SessionKeyring) Remove(key SessionMetadata) error { } func (sk *SessionKeyring) RemoveAll() (n int, err error) { - sk.garbageCollectOnce() - allKeys, err := sk.Keys() if err != nil { return 0, err @@ -265,16 +260,7 @@ func (sk *SessionKeyring) RemoveForProfile(profileName string) (n int, err error return n, nil } -func (sk *SessionKeyring) garbageCollectOnce() (n int, err error) { - if sk.isGarbageCollected { - return - } - return sk.garbageCollect() -} - -func (sk *SessionKeyring) garbageCollect() (n int, err error) { - sk.isGarbageCollected = true - +func (sk *SessionKeyring) RemoveOldSessions() (n int, err error) { allKeys, err := sk.Keyring.Keys() if err != nil { log.Printf("Error while deleting old session: %s", err.Error())