Skip to content

Commit

Permalink
Merge pull request #440 from jstewmon/mfa-session-cache-key
Browse files Browse the repository at this point in the history
fix: session cache key should vary with mfa serial
  • Loading branch information
mtibben committed Oct 30, 2019
2 parents 8ae71d8 + ad2e94a commit 08380e6
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 36 deletions.
10 changes: 7 additions & 3 deletions cli/ls.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,19 @@ func LsCommand(app *kingpin.Application, input LsCommandInput) {
fmt.Fprintf(w, "-\t")
}

var sessionIDs []string
var sessionLabels []string
for _, sess := range sessions {
if profile.Name == sess.Profile.Name {
sessionIDs = append(sessionIDs, sess.SessionID)
label := fmt.Sprintf("%d", sess.Expiration.Unix())
if sess.MfaSerial != "" {
label += " (mfa)"
}
sessionLabels = append(sessionLabels, label)
}
}

if len(sessions) > 0 {
fmt.Fprintf(w, "%s\t\n", strings.Join(sessionIDs, ", "))
fmt.Fprintf(w, "%s\t\n", strings.Join(sessionLabels, ", "))
} else {
fmt.Fprintf(w, "-\t\n")
}
Expand Down
4 changes: 2 additions & 2 deletions vault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) {
}

// sessions get stored by profile, not the source
session, err := p.sessions.Retrieve(p.profile)
session, err := p.sessions.Retrieve(p.profile, p.VaultOptions.MfaSerial)
if err != nil {
if err == keyring.ErrKeyNotFound {
log.Printf("Session not found in keyring for %s", p.profile)
Expand Down Expand Up @@ -138,7 +138,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) {
return credentials.Value{}, err
}

if err = p.sessions.Store(p.profile, session, time.Now().Add(p.SessionDuration)); err != nil {
if err = p.sessions.Store(p.profile, p.VaultOptions.MfaSerial, session); err != nil {
return credentials.Value{}, err
}
}
Expand Down
94 changes: 63 additions & 31 deletions vault/sessions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand All @@ -13,39 +14,70 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
)

var sessionKeyPattern = regexp.MustCompile(`^(.+?) session \((\d+)\)$`)
var sessionKeyPattern = regexp.MustCompile(`^session,(?P<profile>[^,]+),(?P<mfaSerial>[^,]*),(?P<expiration>[^:]+)$`)
var oldSessionKeyPatterns = []*regexp.Regexp{
regexp.MustCompile(`^session:(?P<profile>[^ ]+):(?P<mfaSerial>[^ ]*):(?P<expiration>[^:]+)$`),
regexp.MustCompile(`^(.+?) session \((\d+)\)$`),
}
var base64Encoding = base64.URLEncoding.WithPadding(base64.NoPadding)

func IsSessionKey(s string) bool {
return sessionKeyPattern.MatchString(s)
if sessionKeyPattern.MatchString(s) {
return true
}
for _, pattern := range oldSessionKeyPatterns {
if pattern.MatchString(s) {
return true
}
}
return false
}

func parseKeyringSession(s string, conf *Config) (KeyringSession, error) {
matches := sessionKeyPattern.FindStringSubmatch(s)
func parseSessionKey(key string, conf *Config) (KeyringSession, error) {
matches := sessionKeyPattern.FindStringSubmatch(key)
if len(matches) == 0 {
return KeyringSession{}, errors.New("Failed to parse session name")
return KeyringSession{}, errors.New("failed to parse session name")
}
profileName, err := base64Encoding.DecodeString(matches[1])
if err != nil {
return KeyringSession{}, err
}
mfaSerial, err := base64Encoding.DecodeString(matches[2])
if err != nil {
return KeyringSession{}, err
}
tsInt, err := strconv.ParseInt(matches[3], 10, 64)
if err != nil {
return KeyringSession{}, err
}
profile, _ := conf.Profile(matches[1])
return KeyringSession{Profile: profile, Name: s, SessionID: matches[2]}, nil
profile, _ := conf.Profile(string(profileName))
return KeyringSession{
Profile: profile,
Key: key,
Expiration: time.Unix(tsInt, 0),
MfaSerial: string(mfaSerial),
}, nil
}

func formatSessionKey(profile string, mfaSerial string, expiration *time.Time) string {
return fmt.Sprintf(
"session,%s,%s,%d",
base64Encoding.EncodeToString([]byte(profile)),
base64Encoding.EncodeToString([]byte(mfaSerial)),
expiration.Unix(),
)
}

type KeyringSession struct {
Profile
Name string
SessionID string
Key string
Expiration time.Time
MfaSerial string
}

func (ks KeyringSession) IsExpired() bool {
// Older sessions were 20 characters long and opaque identifiers
if len(ks.SessionID) == 20 {
return true
}
// Now our session id's are timestamps
tsInt, err := strconv.ParseInt(ks.SessionID, 10, 64)
if err != nil {
return true
}
log.Printf("Session %q expires in %v", ks.Name, time.Unix(tsInt, 0).Sub(time.Now()).String())
return time.Now().After(time.Unix(tsInt, 0))
log.Printf("Session %q expires in %v", ks.Key, ks.Expiration.Sub(time.Now()).String())
return time.Now().After(ks.Expiration)
}

type KeyringSessions struct {
Expand All @@ -71,9 +103,9 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) {

for _, k := range keys {
if IsSessionKey(k) {
ks, _ := parseKeyringSession(k, s.Config)
if ks.IsExpired() {
log.Printf("Session %s is expired, attempting deleting", k)
ks, err := parseSessionKey(k, s.Config)
if err != nil || ks.IsExpired() {
log.Printf("Session %s is obsolete, attempting deleting", k)
if err := s.Keyring.Remove(k); err != nil {
log.Printf("Error deleting session: %v", err)
}
Expand All @@ -88,16 +120,16 @@ func (s *KeyringSessions) Sessions() ([]KeyringSession, error) {
}

// Retrieve searches sessions for specific profile, expects the profile to be provided, not the source
func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err error) {
func (s *KeyringSessions) Retrieve(profile string, mfaSerial string) (creds sts.Credentials, err error) {
log.Printf("Looking for sessions for %s", profile)
sessions, err := s.Sessions()
if err != nil {
return creds, err
}

for _, session := range sessions {
if session.Profile.Name == profile {
item, err := s.Keyring.Get(session.Name)
if session.Profile.Name == profile && session.MfaSerial == mfaSerial {
item, err := s.Keyring.Get(session.Key)
if err != nil {
return creds, err
}
Expand All @@ -108,7 +140,7 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e

// double check the actual expiry time
if creds.Expiration.Before(time.Now()) {
log.Printf("Session %q is expired, deleting", session.Name)
log.Printf("Session %q is expired, deleting", session.Key)
if err = s.Keyring.Remove(session.Profile.Name); err != nil {
return creds, err
}
Expand All @@ -123,13 +155,13 @@ func (s *KeyringSessions) Retrieve(profile string) (creds sts.Credentials, err e
}

// Store stores a sessions for a specific profile, expects the profile to be provided, not the source
func (s *KeyringSessions) Store(profile string, session sts.Credentials, expires time.Time) error {
func (s *KeyringSessions) Store(profile string, mfaSerial string, session sts.Credentials) error {
bytes, err := json.Marshal(session)
if err != nil {
return err
}

key := fmt.Sprintf("%s session (%d)", profile, expires.Unix())
key := formatSessionKey(profile, mfaSerial, session.Expiration)
log.Printf("Writing session for %s to keyring: %q", profile, key)

return s.Keyring.Set(keyring.Item{
Expand All @@ -153,8 +185,8 @@ func (s *KeyringSessions) Delete(profile string) (n int, err error) {

for _, session := range sessions {
if session.Profile.Name == profile {
log.Printf("Session %q matches profile %q", session.Name, profile)
if err = s.Keyring.Remove(session.Name); err != nil {
log.Printf("Session %q matches profile %q", session.Key, profile)
if err = s.Keyring.Remove(session.Key); err != nil {
return n, err
}
n++
Expand Down
2 changes: 2 additions & 0 deletions vault/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ func TestIsSessionKey(t *testing.T) {
{"blah", false},
{"blah session (61633665646639303539)", true},
{"blah-iam session (32383863333237616430)", true},
{"session,c2Vzc2lvbg,,1572281751", true},
{"session,c2Vzc2lvbg,YXJuOmF3czppYW06OjEyMzQ1Njc4OTA6bWZhL2pzdGV3bW9u,1572281751", true},
}

for _, tc := range testCases {
Expand Down

0 comments on commit 08380e6

Please sign in to comment.