Skip to content

Commit

Permalink
Check source_profile for session name, add debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
lox committed Oct 19, 2015
1 parent 7d63d81 commit 6a892b9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
48 changes: 34 additions & 14 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type VaultProvider struct {
profiles profiles
session *sts.Credentials
client stsClient
creds map[string]credentials.Value
}

func NewVaultProvider(k keyring.Keyring, profile string, opts VaultOptions) (*VaultProvider, error) {
Expand All @@ -79,6 +80,7 @@ func NewVaultProvider(k keyring.Keyring, profile string, opts VaultOptions) (*Va
keyring: k,
profile: profile,
profiles: profiles,
creds: map[string]credentials.Value{},
}, nil
}

Expand Down Expand Up @@ -115,25 +117,29 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) {
})
}

log.Printf("Using session, expires in %s", session.Expiration.Sub(time.Now()).String())
log.Printf("Using session ****************%s, expires in %s",
(*session.AccessKeyId)[len(*session.AccessKeyId)-4:],
session.Expiration.Sub(time.Now()).String())

window := p.ExpiryWindow
if window == 0 {
window = p.SessionDuration - (p.SessionDuration / 3)
window = time.Minute * 5
}

p.SetExpiration(*session.Expiration, window)
p.expires = *session.Expiration

if role, ok := p.profiles[p.profile]["role_arn"]; ok {
session, err = p.assumeRole(session, role)
if err != nil {
return credentials.Value{}, err
}

log.Printf("Role token expires in %s", session.Expiration.Sub(time.Now()))
log.Printf("Using role ****************%s, expires in %s",
(*session.AccessKeyId)[len(*session.AccessKeyId)-4:],
session.Expiration.Sub(time.Now()).String())
}

p.SetExpiration(*session.Expiration, window)
p.expires = *session.Expiration

value := credentials.Value{
AccessKeyID: *session.AccessKeyId,
SecretAccessKey: *session.SecretAccessKey,
Expand All @@ -148,11 +154,15 @@ func sessionKey(profile string) string {
}

func (p *VaultProvider) getCachedSession() (session sts.Credentials, err error) {
item, err := p.keyring.Get(sessionKey(p.profile))
source := p.profiles.sourceProfile(p.profile)

item, err := p.keyring.Get(sessionKey(source))
if err != nil {
return session, err
}

log.Printf("Found cached session for profile %s", source)

if err = json.Unmarshal(item.Data, &session); err != nil {
return session, err
}
Expand All @@ -167,13 +177,23 @@ func (p *VaultProvider) getCachedSession() (session sts.Credentials, err error)
func (p *VaultProvider) getMasterCreds() (credentials.Value, error) {
source := p.profiles.sourceProfile(p.profile)

provider := credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: p.profile},
&KeyringProvider{Keyring: p.keyring, Profile: source},
})
creds, ok := p.creds[source]
if !ok {
provider := credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: p.profile},
&KeyringProvider{Keyring: p.keyring, Profile: source},
})

var err error
if creds, err = provider.Get(); err != nil {
return creds, err
}

p.creds[source] = creds
}

return provider.Get()
return creds, nil
}

func (p *VaultProvider) getSessionToken(creds *credentials.Value) (sts.Credentials, error) {
Expand Down Expand Up @@ -231,7 +251,7 @@ func (p *VaultProvider) assumeRole(session sts.Credentials, roleArn string) (sts
DurationSeconds: aws.Int64(int64(p.AssumeRoleDuration.Seconds())),
}

log.Printf("Assuming role %s, expires in %s", roleArn, p.AssumeRoleDuration.String())
log.Printf("Assuming role %s", roleArn)
resp, err := client.AssumeRole(input)
if err != nil {
return sts.Credentials{}, err
Expand Down
11 changes: 7 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,10 @@ type metadataHandler struct {
}

func indexHandler(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.RequestURI)

fmt.Fprintf(w, "local-credentials")
}

func credentialsHandler(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.RequestURI)

resp, err := http.Get(localServerUrl)
if err != nil {
http.Error(w, err.Error(), http.StatusGatewayTimeout)
Expand Down Expand Up @@ -97,12 +93,19 @@ func startCredentialsServer(ui Ui, creds *VaultCredentials) error {

log.Printf("Local instance role server running on %s", l.Addr())
go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Credentials.IsExpired() = %#v", creds.IsExpired())

val, err := creds.Get()
if err != nil {
http.Error(w, err.Error(), http.StatusGatewayTimeout)
return
}

log.Printf("Serving credentials via http ****************%s, expiration of %s (%s)",
val.AccessKeyID[len(val.AccessKeyID)-4:],
creds.Expires().Format(awsTimeFormat),
creds.Expires().Sub(time.Now()).String())

json.NewEncoder(w).Encode(map[string]interface{}{
"Code": "Success",
"LastUpdated": time.Now().Format(awsTimeFormat),
Expand Down

0 comments on commit 6a892b9

Please sign in to comment.