diff --git a/cli/ls.go b/cli/ls.go index b70038b4b..8d1d116c5 100644 --- a/cli/ls.go +++ b/cli/ls.go @@ -108,15 +108,19 @@ func LsCommand(app *kingpin.Application, input LsCommandInput) { fmt.Fprintf(w, "-\t") } - var sessionIDs []string + var sessionLabels []string for _, sess := range sessions { if profile.Name == sess.Profile.Name { - sessionIDs = append(sessionIDs, sess.SessionID) + label := fmt.Sprintf("%d", sess.Expiration.Unix()) + if sess.MfaSerial != "" { + label += " (mfa)" + } + sessionLabels = append(sessionLabels, label) } } if len(sessions) > 0 { - fmt.Fprintf(w, "%s\t\n", strings.Join(sessionIDs, ", ")) + fmt.Fprintf(w, "%s\t\n", strings.Join(sessionLabels, ", ")) } else { fmt.Fprintf(w, "-\t\n") } diff --git a/vault/provider.go b/vault/provider.go index 10166ff84..60ce20fd9 100644 --- a/vault/provider.go +++ b/vault/provider.go @@ -100,7 +100,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) { } // sessions get stored by profile, not the source - session, err := p.sessions.Retrieve(p.profile) + session, err := p.sessions.Retrieve(p.profile, p.VaultOptions.MfaSerial) if err != nil { if err == keyring.ErrKeyNotFound { log.Printf("Session not found in keyring for %s", p.profile) @@ -138,7 +138,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) { return credentials.Value{}, err } - if err = p.sessions.Store(p.profile, session, time.Now().Add(p.SessionDuration)); err != nil { + if err = p.sessions.Store(p.profile, p.VaultOptions.MfaSerial, session); err != nil { return credentials.Value{}, err } } diff --git a/vault/sessions.go b/vault/sessions.go index c13b12b76..0c340773f 100644 --- a/vault/sessions.go +++ b/vault/sessions.go @@ -1,6 +1,7 @@ package vault import ( + "encoding/base64" "encoding/json" "errors" "fmt" @@ -13,39 +14,70 @@ import ( "github.com/aws/aws-sdk-go/service/sts" ) -var sessionKeyPattern = regexp.MustCompile(`^(.+?) session \((\d+)\)$`) +var sessionKeyPattern = regexp.MustCompile(`^session,(?P[^,]+),(?P[^,]*),(?P[^:]+)$`) +var oldSessionKeyPatterns = []*regexp.Regexp{ + regexp.MustCompile(`^session:(?P[^ ]+):(?P[^ ]*):(?P[^:]+)$`), + regexp.MustCompile(`^(.+?) session \((\d+)\)$`), +} +var base64Encoding = base64.URLEncoding.WithPadding(base64.NoPadding) func IsSessionKey(s string) bool { - return sessionKeyPattern.MatchString(s) + if sessionKeyPattern.MatchString(s) { + return true + } + for _, pattern := range oldSessionKeyPatterns { + if pattern.MatchString(s) { + return true + } + } + return false } -func parseKeyringSession(s string, conf *Config) (KeyringSession, error) { - matches := sessionKeyPattern.FindStringSubmatch(s) +func parseSessionKey(key string, conf *Config) (KeyringSession, error) { + matches := sessionKeyPattern.FindStringSubmatch(key) if len(matches) == 0 { - return KeyringSession{}, errors.New("Failed to parse session name") + return KeyringSession{}, errors.New("failed to parse session name") + } + profileName, err := base64Encoding.DecodeString(matches[1]) + if err != nil { + return KeyringSession{}, err + } + mfaSerial, err := base64Encoding.DecodeString(matches[2]) + if err != nil { + return KeyringSession{}, err + } + tsInt, err := strconv.ParseInt(matches[3], 10, 64) + if err != nil { + return KeyringSession{}, err } - profile, _ := conf.Profile(matches[1]) - return KeyringSession{Profile: profile, Name: s, SessionID: matches[2]}, nil + profile, _ := conf.Profile(string(profileName)) + return KeyringSession{ + Profile: profile, + Key: key, + Expiration: time.Unix(tsInt, 0), + MfaSerial: string(mfaSerial), + }, nil +} + +func formatSessionKey(profile string, mfaSerial string, expiration *time.Time) string { + return fmt.Sprintf( + "session,%s,%s,%d", + base64Encoding.EncodeToString([]byte(profile)), + base64Encoding.EncodeToString([]byte(mfaSerial)), + expiration.Unix(), + ) } type KeyringSession struct { Profile - Name string - SessionID string + Key string + Expiration time.Time + MfaSerial string } func (ks KeyringSession) IsExpired() bool { - // Older sessions were 20 characters long and opaque identifiers - if len(ks.SessionID) == 20 { - return true - } - // Now our session id's are timestamps - tsInt, err := strconv.ParseInt(ks.SessionID, 10, 64) - if err != nil { - return true - } - log.Printf("Session %q expires in %v", ks.Name, time.Unix(tsInt, 0).Sub(time.Now()).String()) - return time.Now().After(time.Unix(tsInt, 0)) + log.Printf("Session %q expires in %v", ks.Key, ks.Expiration.Sub(time.Now()).String()) + return time.Now().After(ks.Expiration) } type KeyringSessions struct { @@ -71,9 +103,9 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) { for _, k := range keys { if IsSessionKey(k) { - ks, _ := parseKeyringSession(k, s.Config) - if ks.IsExpired() { - log.Printf("Session %s is expired, attempting deleting", k) + ks, err := parseSessionKey(k, s.Config) + if err != nil || ks.IsExpired() { + log.Printf("Session %s is obsolete, attempting deleting", k) if err := s.Keyring.Remove(k); err != nil { log.Printf("Error deleting session: %v", err) } @@ -88,7 +120,7 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) { } // Retrieve searches sessions for specific profile, expects the profile to be provided, not the source -func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err error) { +func (s *KeyringSessions) Retrieve(profile string, mfaSerial string) (creds sts.Credentials, err error) { log.Printf("Looking for sessions for %s", profile) sessions, err := s.Sessions() if err != nil { @@ -96,8 +128,8 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e } for _, session := range sessions { - if session.Profile.Name == profile { - item, err := s.Keyring.Get(session.Name) + if session.Profile.Name == profile && session.MfaSerial == mfaSerial { + item, err := s.Keyring.Get(session.Key) if err != nil { return creds, err } @@ -108,7 +140,7 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e // double check the actual expiry time if creds.Expiration.Before(time.Now()) { - log.Printf("Session %q is expired, deleting", session.Name) + log.Printf("Session %q is expired, deleting", session.Key) if err = s.Keyring.Remove(session.Profile.Name); err != nil { return creds, err } @@ -123,13 +155,13 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e } // Store stores a sessions for a specific profile, expects the profile to be provided, not the source -func (s *KeyringSessions) Store(profile string, session sts.Credentials, expires time.Time) error { +func (s *KeyringSessions) Store(profile string, mfaSerial string, session sts.Credentials) error { bytes, err := json.Marshal(session) if err != nil { return err } - key := fmt.Sprintf("%s session (%d)", profile, expires.Unix()) + key := formatSessionKey(profile, mfaSerial, session.Expiration) log.Printf("Writing session for %s to keyring: %q", profile, key) return s.Keyring.Set(keyring.Item{ @@ -153,8 +185,8 @@ func (s *KeyringSessions) Delete(profile string) (n int, err error) { for _, session := range sessions { if session.Profile.Name == profile { - log.Printf("Session %q matches profile %q", session.Name, profile) - if err = s.Keyring.Remove(session.Name); err != nil { + log.Printf("Session %q matches profile %q", session.Key, profile) + if err = s.Keyring.Remove(session.Key); err != nil { return n, err } n++ diff --git a/vault/sessions_test.go b/vault/sessions_test.go index 2e7fc7f96..3802eda01 100644 --- a/vault/sessions_test.go +++ b/vault/sessions_test.go @@ -14,6 +14,8 @@ func TestIsSessionKey(t *testing.T) { {"blah", false}, {"blah session (61633665646639303539)", true}, {"blah-iam session (32383863333237616430)", true}, + {"session,c2Vzc2lvbg,,1572281751", true}, + {"session,c2Vzc2lvbg,YXJuOmF3czppYW06OjEyMzQ1Njc4OTA6bWZhL2pzdGV3bW9u,1572281751", true}, } for _, tc := range testCases {