From 36c2635d95749f8b94e62bf5b11de3e48cce32c9 Mon Sep 17 00:00:00 2001 From: Jonathan Stewmon Date: Mon, 28 Oct 2019 13:01:27 -0500 Subject: [PATCH 1/2] fix: session cache key should vary with mfa serial If the master credentials were used as the source for: - role A without mfa_serial - role B with mfa_serial Then, creating a cached session with role A would result in operations with role B failing because the temporary session established when assuming role A would not have been created with an mfa token. The mfa serial is now encoded in the cache key, so that the scenario mentioned above produces distinct cache keys - one with an mfa serial and one without an mfa serial. This also resolves a potentially unexpected behavior when a user changes the value of a mfa_serial attribute. Because the session key format has changed, any sessions matching the previous pattern will be identified as obsolete and removed, similarly to how expired sessions are removed as they are visited. --- vault/provider.go | 4 ++-- vault/sessions.go | 51 ++++++++++++++++++++++++++++++++---------- vault/sessions_test.go | 2 ++ 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/vault/provider.go b/vault/provider.go index 10166ff84..60ce20fd9 100644 --- a/vault/provider.go +++ b/vault/provider.go @@ -100,7 +100,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) { } // sessions get stored by profile, not the source - session, err := p.sessions.Retrieve(p.profile) + session, err := p.sessions.Retrieve(p.profile, p.VaultOptions.MfaSerial) if err != nil { if err == keyring.ErrKeyNotFound { log.Printf("Session not found in keyring for %s", p.profile) @@ -138,7 +138,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) { return credentials.Value{}, err } - if err = p.sessions.Store(p.profile, session, time.Now().Add(p.SessionDuration)); err != nil { + if err = p.sessions.Store(p.profile, p.VaultOptions.MfaSerial, session); err != nil { return credentials.Value{}, err } } diff --git a/vault/sessions.go b/vault/sessions.go index c13b12b76..d9cb91f91 100644 --- a/vault/sessions.go +++ b/vault/sessions.go @@ -1,6 +1,7 @@ package vault import ( + "encoding/base64" "encoding/json" "errors" "fmt" @@ -13,25 +14,46 @@ import ( "github.com/aws/aws-sdk-go/service/sts" ) -var sessionKeyPattern = regexp.MustCompile(`^(.+?) session \((\d+)\)$`) +var sessionKeyPattern = regexp.MustCompile(`^session:(?P[^:]+):(?P[^:]*):(?P[^:]+)$`) +var oldSessionKeyPatterns = []*regexp.Regexp{ + regexp.MustCompile(`^(.+?) session \((\d+)\)$`), +} +var base64Encoding = base64.URLEncoding.WithPadding(base64.NoPadding) func IsSessionKey(s string) bool { - return sessionKeyPattern.MatchString(s) + if sessionKeyPattern.MatchString(s) { + return true + } + for _, pattern := range oldSessionKeyPatterns { + if pattern.MatchString(s) { + return true + } + } + return false } func parseKeyringSession(s string, conf *Config) (KeyringSession, error) { matches := sessionKeyPattern.FindStringSubmatch(s) if len(matches) == 0 { - return KeyringSession{}, errors.New("Failed to parse session name") + return KeyringSession{}, errors.New("failed to parse session name") } - profile, _ := conf.Profile(matches[1]) - return KeyringSession{Profile: profile, Name: s, SessionID: matches[2]}, nil + profileName, _ := base64Encoding.DecodeString(matches[1]) + mfaSerial, _ := base64Encoding.DecodeString(matches[2]) + sessionId := matches[3] + profile, _ := conf.Profile(string(profileName)) + return KeyringSession{ + Profile: profile, + Name: s, + SessionID: sessionId, + MfaSerial: string(mfaSerial), + }, nil } type KeyringSession struct { Profile Name string SessionID string + MfaSerial string } func (ks KeyringSession) IsExpired() bool { @@ -71,9 +93,9 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) { for _, k := range keys { if IsSessionKey(k) { - ks, _ := parseKeyringSession(k, s.Config) - if ks.IsExpired() { - log.Printf("Session %s is expired, attempting deleting", k) + ks, err := parseKeyringSession(k, s.Config) + if err != nil || ks.IsExpired() { + log.Printf("Session %s is obsolete, attempting deleting", k) if err := s.Keyring.Remove(k); err != nil { log.Printf("Error deleting session: %v", err) } @@ -88,7 +110,7 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) { } // Retrieve searches sessions for specific profile, expects the profile to be provided, not the source -func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err error) { +func (s *KeyringSessions) Retrieve(profile string, mfaSerial string) (creds sts.Credentials, err error) { log.Printf("Looking for sessions for %s", profile) sessions, err := s.Sessions() if err != nil { @@ -96,7 +118,7 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e } for _, session := range sessions { - if session.Profile.Name == profile { + if session.Profile.Name == profile && session.MfaSerial == mfaSerial { item, err := s.Keyring.Get(session.Name) if err != nil { return creds, err @@ -123,13 +145,18 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e } // Store stores a sessions for a specific profile, expects the profile to be provided, not the source -func (s *KeyringSessions) Store(profile string, session sts.Credentials, expires time.Time) error { +func (s *KeyringSessions) Store(profile string, mfaSerial string, session sts.Credentials) error { bytes, err := json.Marshal(session) if err != nil { return err } - key := fmt.Sprintf("%s session (%d)", profile, expires.Unix()) + key := fmt.Sprintf( + "session:%s:%s:%d", + base64Encoding.EncodeToString([]byte(profile)), + base64Encoding.EncodeToString([]byte(mfaSerial)), + session.Expiration.Unix(), + ) log.Printf("Writing session for %s to keyring: %q", profile, key) return s.Keyring.Set(keyring.Item{ diff --git a/vault/sessions_test.go b/vault/sessions_test.go index 2e7fc7f96..2db433f0e 100644 --- a/vault/sessions_test.go +++ b/vault/sessions_test.go @@ -14,6 +14,8 @@ func TestIsSessionKey(t *testing.T) { {"blah", false}, {"blah session (61633665646639303539)", true}, {"blah-iam session (32383863333237616430)", true}, + {"session:c2Vzc2lvbg::1572281751", true}, + {"session:c2Vzc2lvbg:YXJuOmF3czppYW06OjEyMzQ1Njc4OTA6bWZhL2pzdGV3bW9u:1572281751", true}, } for _, tc := range testCases { From ad2e94a21b727322b01820c15cc2b545df641a26 Mon Sep 17 00:00:00 2001 From: Jonathan Stewmon Date: Tue, 29 Oct 2019 11:40:46 -0500 Subject: [PATCH 2/2] fix: format session key for filename compatibility The session key has to be compatible with all Keyring backend stores, which includes a File store that uses the key as a filename. Colon is a disallowed filename character on Windows, so the separator was changed to a comma. KeyringSession fields have been updated to reflect the structured data encoded in the session key, and the parser now only returns a populated KeyringSession if the key was in the expected format. --- cli/ls.go | 10 ++++-- vault/sessions.go | 75 ++++++++++++++++++++++-------------------- vault/sessions_test.go | 4 +-- 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/cli/ls.go b/cli/ls.go index b70038b4b..8d1d116c5 100644 --- a/cli/ls.go +++ b/cli/ls.go @@ -108,15 +108,19 @@ func LsCommand(app *kingpin.Application, input LsCommandInput) { fmt.Fprintf(w, "-\t") } - var sessionIDs []string + var sessionLabels []string for _, sess := range sessions { if profile.Name == sess.Profile.Name { - sessionIDs = append(sessionIDs, sess.SessionID) + label := fmt.Sprintf("%d", sess.Expiration.Unix()) + if sess.MfaSerial != "" { + label += " (mfa)" + } + sessionLabels = append(sessionLabels, label) } } if len(sessions) > 0 { - fmt.Fprintf(w, "%s\t\n", strings.Join(sessionIDs, ", ")) + fmt.Fprintf(w, "%s\t\n", strings.Join(sessionLabels, ", ")) } else { fmt.Fprintf(w, "-\t\n") } diff --git a/vault/sessions.go b/vault/sessions.go index d9cb91f91..0c340773f 100644 --- a/vault/sessions.go +++ b/vault/sessions.go @@ -14,8 +14,9 @@ import ( "github.com/aws/aws-sdk-go/service/sts" ) -var sessionKeyPattern = regexp.MustCompile(`^session:(?P[^:]+):(?P[^:]*):(?P[^:]+)$`) +var sessionKeyPattern = regexp.MustCompile(`^session,(?P[^,]+),(?P[^,]*),(?P[^:]+)$`) var oldSessionKeyPatterns = []*regexp.Regexp{ + regexp.MustCompile(`^session:(?P[^ ]+):(?P[^ ]*):(?P[^:]+)$`), regexp.MustCompile(`^(.+?) session \((\d+)\)$`), } var base64Encoding = base64.URLEncoding.WithPadding(base64.NoPadding) @@ -32,42 +33,51 @@ func IsSessionKey(s string) bool { return false } -func parseKeyringSession(s string, conf *Config) (KeyringSession, error) { - matches := sessionKeyPattern.FindStringSubmatch(s) +func parseSessionKey(key string, conf *Config) (KeyringSession, error) { + matches := sessionKeyPattern.FindStringSubmatch(key) if len(matches) == 0 { return KeyringSession{}, errors.New("failed to parse session name") } - profileName, _ := base64Encoding.DecodeString(matches[1]) - mfaSerial, _ := base64Encoding.DecodeString(matches[2]) - sessionId := matches[3] + profileName, err := base64Encoding.DecodeString(matches[1]) + if err != nil { + return KeyringSession{}, err + } + mfaSerial, err := base64Encoding.DecodeString(matches[2]) + if err != nil { + return KeyringSession{}, err + } + tsInt, err := strconv.ParseInt(matches[3], 10, 64) + if err != nil { + return KeyringSession{}, err + } profile, _ := conf.Profile(string(profileName)) return KeyringSession{ - Profile: profile, - Name: s, - SessionID: sessionId, - MfaSerial: string(mfaSerial), + Profile: profile, + Key: key, + Expiration: time.Unix(tsInt, 0), + MfaSerial: string(mfaSerial), }, nil } +func formatSessionKey(profile string, mfaSerial string, expiration *time.Time) string { + return fmt.Sprintf( + "session,%s,%s,%d", + base64Encoding.EncodeToString([]byte(profile)), + base64Encoding.EncodeToString([]byte(mfaSerial)), + expiration.Unix(), + ) +} + type KeyringSession struct { Profile - Name string - SessionID string - MfaSerial string + Key string + Expiration time.Time + MfaSerial string } func (ks KeyringSession) IsExpired() bool { - // Older sessions were 20 characters long and opaque identifiers - if len(ks.SessionID) == 20 { - return true - } - // Now our session id's are timestamps - tsInt, err := strconv.ParseInt(ks.SessionID, 10, 64) - if err != nil { - return true - } - log.Printf("Session %q expires in %v", ks.Name, time.Unix(tsInt, 0).Sub(time.Now()).String()) - return time.Now().After(time.Unix(tsInt, 0)) + log.Printf("Session %q expires in %v", ks.Key, ks.Expiration.Sub(time.Now()).String()) + return time.Now().After(ks.Expiration) } type KeyringSessions struct { @@ -93,7 +103,7 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) { for _, k := range keys { if IsSessionKey(k) { - ks, err := parseKeyringSession(k, s.Config) + ks, err := parseSessionKey(k, s.Config) if err != nil || ks.IsExpired() { log.Printf("Session %s is obsolete, attempting deleting", k) if err := s.Keyring.Remove(k); err != nil { @@ -119,7 +129,7 @@ func (s *KeyringSessions) Retrieve(profile string, mfaSerial string) (creds sts. for _, session := range sessions { if session.Profile.Name == profile && session.MfaSerial == mfaSerial { - item, err := s.Keyring.Get(session.Name) + item, err := s.Keyring.Get(session.Key) if err != nil { return creds, err } @@ -130,7 +140,7 @@ func (s *KeyringSessions) Retrieve(profile string, mfaSerial string) (creds sts. // double check the actual expiry time if creds.Expiration.Before(time.Now()) { - log.Printf("Session %q is expired, deleting", session.Name) + log.Printf("Session %q is expired, deleting", session.Key) if err = s.Keyring.Remove(session.Profile.Name); err != nil { return creds, err } @@ -151,12 +161,7 @@ func (s *KeyringSessions) Store(profile string, mfaSerial string, session sts.Cr return err } - key := fmt.Sprintf( - "session:%s:%s:%d", - base64Encoding.EncodeToString([]byte(profile)), - base64Encoding.EncodeToString([]byte(mfaSerial)), - session.Expiration.Unix(), - ) + key := formatSessionKey(profile, mfaSerial, session.Expiration) log.Printf("Writing session for %s to keyring: %q", profile, key) return s.Keyring.Set(keyring.Item{ @@ -180,8 +185,8 @@ func (s *KeyringSessions) Delete(profile string) (n int, err error) { for _, session := range sessions { if session.Profile.Name == profile { - log.Printf("Session %q matches profile %q", session.Name, profile) - if err = s.Keyring.Remove(session.Name); err != nil { + log.Printf("Session %q matches profile %q", session.Key, profile) + if err = s.Keyring.Remove(session.Key); err != nil { return n, err } n++ diff --git a/vault/sessions_test.go b/vault/sessions_test.go index 2db433f0e..3802eda01 100644 --- a/vault/sessions_test.go +++ b/vault/sessions_test.go @@ -14,8 +14,8 @@ func TestIsSessionKey(t *testing.T) { {"blah", false}, {"blah session (61633665646639303539)", true}, {"blah-iam session (32383863333237616430)", true}, - {"session:c2Vzc2lvbg::1572281751", true}, - {"session:c2Vzc2lvbg:YXJuOmF3czppYW06OjEyMzQ1Njc4OTA6bWZhL2pzdGV3bW9u:1572281751", true}, + {"session,c2Vzc2lvbg,,1572281751", true}, + {"session,c2Vzc2lvbg,YXJuOmF3czppYW06OjEyMzQ1Njc4OTA6bWZhL2pzdGV3bW9u,1572281751", true}, } for _, tc := range testCases {