Skip to content

Commit

Permalink
Report number of cleared sessions more accurately
Browse files Browse the repository at this point in the history
  • Loading branch information
mtibben committed Sep 1, 2020
1 parent 6d1f063 commit 56986de
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
8 changes: 6 additions & 2 deletions cli/clear.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ func ConfigureClearCommand(app *kingpin.Application, a *AwsVault) {
func ClearCommand(input ClearCommandInput, awsConfigFile *vault.ConfigFile, keyring keyring.Keyring) error {
sessions := &vault.SessionKeyring{Keyring: keyring}
oidcTokens := &vault.OIDCTokenKeyring{Keyring: keyring}
var numSessionsRemoved, numTokensRemoved int
var oldSessionsRemoved, numSessionsRemoved, numTokensRemoved int
var err error
if input.ProfileName == "" {
oldSessionsRemoved, err = sessions.RemoveOldSessions()
if err != nil {
return err
}
numSessionsRemoved, err = sessions.RemoveAll()
if err != nil {
return err
Expand All @@ -67,7 +71,7 @@ func ClearCommand(input ClearCommandInput, awsConfigFile *vault.ConfigFile, keyr
}
}
}
fmt.Printf("Cleared %d sessions.\n", numSessionsRemoved+numTokensRemoved)
fmt.Printf("Cleared %d sessions.\n", oldSessionsRemoved+numSessionsRemoved+numTokensRemoved)

return nil
}
36 changes: 11 additions & 25 deletions vault/sessionkeyring.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ var oldSessionKeyPatterns = []*regexp.Regexp{
regexp.MustCompile(`^session:(?P<profile>[^ ]+):(?P<mfaSerial>[^ ]*):(?P<expiration>[^:]+)$`),
regexp.MustCompile(`^(.+?) session \((\d+)\)$`),
}
var base64Encoding = base64.URLEncoding.WithPadding(base64.NoPadding)
var base64URLEncodingNoPadding = base64.URLEncoding.WithPadding(base64.NoPadding)

func IsOldSessionKey(s string) bool {
for _, pattern := range oldSessionKeyPatterns {
Expand Down Expand Up @@ -52,8 +52,8 @@ func (k *SessionMetadata) String() string {
return fmt.Sprintf(
"%s,%s,%s,%d",
k.Type,
base64Encoding.EncodeToString([]byte(k.ProfileName)),
base64Encoding.EncodeToString([]byte(k.MfaSerial)),
base64URLEncodingNoPadding.EncodeToString([]byte(k.ProfileName)),
base64URLEncodingNoPadding.EncodeToString([]byte(k.MfaSerial)),
k.Expiration.Unix(),
)
}
Expand All @@ -62,8 +62,8 @@ func (k *SessionMetadata) StringForMatching() string {
return fmt.Sprintf(
"%s,%s,%s,",
k.Type,
base64Encoding.EncodeToString([]byte(k.ProfileName)),
base64Encoding.EncodeToString([]byte(k.MfaSerial)),
base64URLEncodingNoPadding.EncodeToString([]byte(k.ProfileName)),
base64URLEncodingNoPadding.EncodeToString([]byte(k.MfaSerial)),
)
}

Expand All @@ -73,11 +73,11 @@ func NewSessionKeyFromString(s string) (SessionMetadata, error) {
return SessionMetadata{}, fmt.Errorf("failed to parse session name: %s", s)
}

profileName, err := base64Encoding.DecodeString(matches[2])
profileName, err := base64URLEncodingNoPadding.DecodeString(matches[2])
if err != nil {
return SessionMetadata{}, err
}
mfaSerial, err := base64Encoding.DecodeString(matches[3])
mfaSerial, err := base64URLEncodingNoPadding.DecodeString(matches[3])
if err != nil {
return SessionMetadata{}, err
}
Expand All @@ -95,8 +95,7 @@ func NewSessionKeyFromString(s string) (SessionMetadata, error) {
}

type SessionKeyring struct {
Keyring keyring.Keyring
isGarbageCollected bool
Keyring keyring.Keyring
}

var ErrNotFound = keyring.ErrKeyNotFound
Expand Down Expand Up @@ -127,7 +126,7 @@ func (sk *SessionKeyring) Has(key SessionMetadata) (bool, error) {
}

func (sk *SessionKeyring) Get(key SessionMetadata) (val *sts.Credentials, err error) {
sk.garbageCollect()
sk.RemoveOldSessions()

keyName, err := sk.lookupKeyName(key)
if err != nil && err != ErrNotFound {
Expand All @@ -145,7 +144,7 @@ func (sk *SessionKeyring) Get(key SessionMetadata) (val *sts.Credentials, err er
}

func (sk *SessionKeyring) Set(key SessionMetadata, val *sts.Credentials) error {
sk.garbageCollectOnce()
sk.RemoveOldSessions()

key.Expiration = *val.Expiration

Expand Down Expand Up @@ -176,8 +175,6 @@ func (sk *SessionKeyring) Set(key SessionMetadata, val *sts.Credentials) error {
}

func (sk *SessionKeyring) Remove(key SessionMetadata) error {
sk.garbageCollectOnce()

keyName, err := sk.lookupKeyName(key)
if err != nil && err != ErrNotFound {
return err
Expand All @@ -187,8 +184,6 @@ func (sk *SessionKeyring) Remove(key SessionMetadata) error {
}

func (sk *SessionKeyring) RemoveAll() (n int, err error) {
sk.garbageCollectOnce()

allKeys, err := sk.Keys()
if err != nil {
return 0, err
Expand Down Expand Up @@ -265,16 +260,7 @@ func (sk *SessionKeyring) RemoveForProfile(profileName string) (n int, err error
return n, nil
}

func (sk *SessionKeyring) garbageCollectOnce() (n int, err error) {
if sk.isGarbageCollected {
return
}
return sk.garbageCollect()
}

func (sk *SessionKeyring) garbageCollect() (n int, err error) {
sk.isGarbageCollected = true

func (sk *SessionKeyring) RemoveOldSessions() (n int, err error) {
allKeys, err := sk.Keyring.Keys()
if err != nil {
log.Printf("Error while deleting old session: %s", err.Error())
Expand Down

0 comments on commit 56986de

Please sign in to comment.