Skip to content

Commit

Permalink
Check session keys and re-fetch if they are expired
Browse files Browse the repository at this point in the history
  • Loading branch information
lox committed Sep 3, 2015
1 parent 2e90ce2 commit 6feda5d
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions provider.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"fmt"
"log"
"time"
Expand All @@ -9,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/bgentry/speakeasy"
)

const (
Expand Down Expand Up @@ -48,10 +48,9 @@ func NewVaultProvider(k keyring.Keyring, profile string) (*VaultProvider, error)
}

func (p *VaultProvider) Retrieve() (credentials.Value, error) {
var session sts.Credentials

if err := keyring.Unmarshal(p.Keyring, sessionServiceName, p.Profile, &session); err != nil {
log.Println("Session lookup failed", err)
session, err := p.getCachedSession()
if err != nil {
log.Println(err)

session, err = p.getSessionToken(p.SessionDuration)
if err != nil {
Expand All @@ -66,8 +65,6 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) {
}

keyring.Marshal(p.Keyring, sessionServiceName, p.Profile, session)
} else {
log.Printf("Found a cached session token for %s", p.Profile)
}

log.Printf("Session token expires in %s", session.Expiration.Sub(time.Now()))
Expand All @@ -82,6 +79,18 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) {
return value, nil
}

func (p *VaultProvider) getCachedSession() (session sts.Credentials, err error) {
if err = keyring.Unmarshal(p.Keyring, sessionServiceName, p.Profile, &session); err != nil {
return session, err
}

if session.Expiration.Before(time.Now()) {
return session, errors.New("Session is expired")
}

return
}

func (p *VaultProvider) getSessionToken(length time.Duration) (sts.Credentials, error) {
source := p.profilesConf.sourceProfile(p.Profile)

Expand All @@ -90,7 +99,7 @@ func (p *VaultProvider) getSessionToken(length time.Duration) (sts.Credentials,
}

if mfa, ok := p.profilesConf[p.Profile]["mfa_serial"]; ok {
token, err := speakeasy.Ask(fmt.Sprintf("Enter token for %s: ", mfa))
token, err := promptPassword(fmt.Sprintf("Enter token for %s: ", mfa))
if err != nil {
return sts.Credentials{}, err
}
Expand Down

0 comments on commit 6feda5d

Please sign in to comment.