From 6a892b9b0aa8102bc02c8d06135cfec03a1e0d3d Mon Sep 17 00:00:00 2001 From: Lachlan Donald Date: Mon, 19 Oct 2015 11:46:26 +1100 Subject: [PATCH] Check source_profile for session name, add debugging --- provider.go | 48 ++++++++++++++++++++++++++++++++++-------------- server.go | 11 +++++++---- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/provider.go b/provider.go index c3de56df0..0cfdb331e 100644 --- a/provider.go +++ b/provider.go @@ -63,6 +63,7 @@ type VaultProvider struct { profiles profiles session *sts.Credentials client stsClient + creds map[string]credentials.Value } func NewVaultProvider(k keyring.Keyring, profile string, opts VaultOptions) (*VaultProvider, error) { @@ -79,6 +80,7 @@ func NewVaultProvider(k keyring.Keyring, profile string, opts VaultOptions) (*Va keyring: k, profile: profile, profiles: profiles, + creds: map[string]credentials.Value{}, }, nil } @@ -115,25 +117,29 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) { }) } - log.Printf("Using session, expires in %s", session.Expiration.Sub(time.Now()).String()) + log.Printf("Using session ****************%s, expires in %s", + (*session.AccessKeyId)[len(*session.AccessKeyId)-4:], + session.Expiration.Sub(time.Now()).String()) window := p.ExpiryWindow if window == 0 { - window = p.SessionDuration - (p.SessionDuration / 3) + window = time.Minute * 5 } - p.SetExpiration(*session.Expiration, window) - p.expires = *session.Expiration - if role, ok := p.profiles[p.profile]["role_arn"]; ok { session, err = p.assumeRole(session, role) if err != nil { return credentials.Value{}, err } - log.Printf("Role token expires in %s", session.Expiration.Sub(time.Now())) + log.Printf("Using role ****************%s, expires in %s", + (*session.AccessKeyId)[len(*session.AccessKeyId)-4:], + session.Expiration.Sub(time.Now()).String()) } + p.SetExpiration(*session.Expiration, window) + p.expires = *session.Expiration + value := credentials.Value{ AccessKeyID: *session.AccessKeyId, SecretAccessKey: *session.SecretAccessKey, @@ -148,11 +154,15 @@ func sessionKey(profile string) string { } func (p *VaultProvider) getCachedSession() (session sts.Credentials, err error) { - item, err := p.keyring.Get(sessionKey(p.profile)) + source := p.profiles.sourceProfile(p.profile) + + item, err := p.keyring.Get(sessionKey(source)) if err != nil { return session, err } + log.Printf("Found cached session for profile %s", source) + if err = json.Unmarshal(item.Data, &session); err != nil { return session, err } @@ -167,13 +177,23 @@ func (p *VaultProvider) getCachedSession() (session sts.Credentials, err error) func (p *VaultProvider) getMasterCreds() (credentials.Value, error) { source := p.profiles.sourceProfile(p.profile) - provider := credentials.NewChainCredentials([]credentials.Provider{ - &credentials.EnvProvider{}, - &credentials.SharedCredentialsProvider{Filename: "", Profile: p.profile}, - &KeyringProvider{Keyring: p.keyring, Profile: source}, - }) + creds, ok := p.creds[source] + if !ok { + provider := credentials.NewChainCredentials([]credentials.Provider{ + &credentials.EnvProvider{}, + &credentials.SharedCredentialsProvider{Filename: "", Profile: p.profile}, + &KeyringProvider{Keyring: p.keyring, Profile: source}, + }) + + var err error + if creds, err = provider.Get(); err != nil { + return creds, err + } + + p.creds[source] = creds + } - return provider.Get() + return creds, nil } func (p *VaultProvider) getSessionToken(creds *credentials.Value) (sts.Credentials, error) { @@ -231,7 +251,7 @@ func (p *VaultProvider) assumeRole(session sts.Credentials, roleArn string) (sts DurationSeconds: aws.Int64(int64(p.AssumeRoleDuration.Seconds())), } - log.Printf("Assuming role %s, expires in %s", roleArn, p.AssumeRoleDuration.String()) + log.Printf("Assuming role %s", roleArn) resp, err := client.AssumeRole(input) if err != nil { return sts.Credentials{}, err diff --git a/server.go b/server.go index 55e244f73..ed16a20db 100644 --- a/server.go +++ b/server.go @@ -46,14 +46,10 @@ type metadataHandler struct { } func indexHandler(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s", r.Method, r.RequestURI) - fmt.Fprintf(w, "local-credentials") } func credentialsHandler(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s", r.Method, r.RequestURI) - resp, err := http.Get(localServerUrl) if err != nil { http.Error(w, err.Error(), http.StatusGatewayTimeout) @@ -97,12 +93,19 @@ func startCredentialsServer(ui Ui, creds *VaultCredentials) error { log.Printf("Local instance role server running on %s", l.Addr()) go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("Credentials.IsExpired() = %#v", creds.IsExpired()) + val, err := creds.Get() if err != nil { http.Error(w, err.Error(), http.StatusGatewayTimeout) return } + log.Printf("Serving credentials via http ****************%s, expiration of %s (%s)", + val.AccessKeyID[len(val.AccessKeyID)-4:], + creds.Expires().Format(awsTimeFormat), + creds.Expires().Sub(time.Now()).String()) + json.NewEncoder(w).Encode(map[string]interface{}{ "Code": "Success", "LastUpdated": time.Now().Format(awsTimeFormat),