From bfa952d72a0aa5dc575f339115adc3a11fe0e115 Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Mon, 20 Mar 2023 10:26:38 +1100 Subject: [PATCH] Check session identity when creds are static --- cli/login.go | 204 +++++++++++++++------ vault/assumeroleprovider.go | 4 +- vault/assumerolewithwebidentityprovider.go | 4 +- vault/cachedsessionprovider.go | 28 ++- vault/credentialprocessprovider.go | 2 +- vault/sessiontokenprovider.go | 4 +- vault/ssorolecredentialsprovider.go | 4 + vault/vault.go | 39 ++-- 8 files changed, 195 insertions(+), 94 deletions(-) diff --git a/cli/login.go b/cli/login.go index 1977b4366..7e248d7e9 100644 --- a/cli/login.go +++ b/cli/login.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/skratchdot/open-golang/open" ) @@ -74,108 +75,99 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { return err } - err = LoginCommand(input, f, keyring) + err = LoginCommand(context.Background(), input, f, keyring) app.FatalIfError(err, "login") return nil }) } -func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error { - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName) - if err != nil { - return fmt.Errorf("Error loading config: %w", err) - } - - var credsProvider aws.CredentialsProvider - +func getCredsProvider(input LoginCommandInput, config *vault.ProfileConfig, keyring keyring.Keyring) (credsProvider aws.CredentialsProvider, err error) { if input.ProfileName == "" { // When no profile is specified, source credentials from the environment configFromEnv, err := awsconfig.NewEnvConfig() if err != nil { - return fmt.Errorf("unable to authenticate to AWS through your environment variables: %w", err) + return nil, fmt.Errorf("unable to authenticate to AWS through your environment variables: %w", err) } - credsProvider = credentials.StaticCredentialsProvider{Value: configFromEnv.Credentials} - if configFromEnv.Credentials.SessionToken == "" { - credsProvider, err = vault.NewFederationTokenProvider(context.TODO(), credsProvider, config) - if err != nil { - return err - } + + if configFromEnv.Credentials.AccessKeyID == "" { + return nil, fmt.Errorf("argument 'profile' not provided, nor any AWS env vars found. Try --help") } + + credsProvider = credentials.StaticCredentialsProvider{Value: configFromEnv.Credentials} } else { // Use a profile from the AWS config file ckr := &vault.CredentialKeyring{Keyring: keyring} - if config.HasRole() || config.HasSSOStartURL() || config.HasCredentialProcess() || config.HasWebIdentity() { - // If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials - credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false) - } else { - credsProvider, err = vault.NewFederationTokenCredentialsProvider(context.TODO(), input.ProfileName, ckr, config) + t := vault.TempCredentialsCreator{ + Keyring: ckr, + DisableSessions: input.NoSession, + DisableSessionsForProfile: config.ProfileName, } + credsProvider, err = t.GetProviderForProfile(config) if err != nil { - return fmt.Errorf("profile %s: %w", input.ProfileName, err) + return nil, fmt.Errorf("profile %s: %w", input.ProfileName, err) } } - creds, err := credsProvider.Retrieve(context.TODO()) + return credsProvider, err +} + +// LoginCommand creates a login URL for the AWS Management Console using the method described at +// https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_enable-console-custom-url.html +func LoginCommand(ctx context.Context, input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error { + config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName) if err != nil { - return fmt.Errorf("Failed to get credentials: %w", err) - } - if creds.AccessKeyID == "" && input.ProfileName == "" { - return fmt.Errorf("argument 'profile' not provided, nor any AWS env vars found. Try --help") + return fmt.Errorf("Error loading config: %w", err) } - jsonBytes, err := json.Marshal(map[string]string{ - "sessionId": creds.AccessKeyID, - "sessionKey": creds.SecretAccessKey, - "sessionToken": creds.SessionToken, - }) + credsProvider, err := getCredsProvider(input, config, keyring) if err != nil { return err } - loginURLPrefix, destination := generateLoginURL(config.Region, input.Path) - - req, err := http.NewRequestWithContext(context.TODO(), "GET", loginURLPrefix, nil) + // if we already know the type of credentials being created, avoid calling isCallerIdentityAssumedRole + canCredsBeUsedInLoginURL, err := canProviderBeUsedForLogin(credsProvider) if err != nil { return err } - if creds.CanExpire { - log.Printf("Creating login token, expires in %s", time.Until(creds.Expires)) - } + if !canCredsBeUsedInLoginURL { + // use a static creds provider so that we don't request credentials from AWS more than once + credsProvider, err = createStaticCredentialsProvider(ctx, credsProvider) + if err != nil { + return err + } - q := req.URL.Query() - q.Add("Action", "getSigninToken") - q.Add("Session", string(jsonBytes)) - req.URL.RawQuery = q.Encode() + // if the credentials have come from an unknown source like credential_process, check the + // caller identity to see if it's an assumed role + isAssumedRole, err := isCallerIdentityAssumedRole(ctx, credsProvider, config) + if err != nil { + return err + } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err + if !isAssumedRole { + log.Println("Creating a federated session") + credsProvider, err = vault.NewFederationTokenProvider(ctx, credsProvider, config) + if err != nil { + return err + } + } } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + creds, err := credsProvider.Retrieve(ctx) if err != nil { return err } - if resp.StatusCode != http.StatusOK { - log.Printf("Response body was %s", body) - return fmt.Errorf("Call to getSigninToken failed with %v", resp.Status) + if creds.CanExpire { + log.Printf("Requesting a signin token for session expiring in %s", time.Until(creds.Expires)) } - var respParsed map[string]string - - err = json.Unmarshal(body, &respParsed) + loginURLPrefix, destination := generateLoginURL(config.Region, input.Path) + signinToken, err := requestSigninToken(ctx, creds, loginURLPrefix) if err != nil { return err } - signinToken, ok := respParsed["SigninToken"] - if !ok { - return fmt.Errorf("Expected a response with SigninToken") - } - loginURL := fmt.Sprintf("%s?Action=login&Issuer=aws-vault&Destination=%s&SigninToken=%s", loginURLPrefix, url.QueryEscape(destination), url.QueryEscape(signinToken)) @@ -212,3 +204,99 @@ func generateLoginURL(region string, path string) (string, string) { } return loginURLPrefix, destination } + +func isCallerIdentityAssumedRole(ctx context.Context, credsProvider aws.CredentialsProvider, config *vault.ProfileConfig) (bool, error) { + cfg := vault.NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints) + client := sts.NewFromConfig(cfg) + id, err := client.GetCallerIdentity(ctx, nil) + if err != nil { + return false, err + } + arn := aws.ToString(id.Arn) + arnParts := strings.Split(arn, ":") + if len(arnParts) < 6 { + return false, fmt.Errorf("unable to parse ARN: %s", arn) + } + if strings.HasPrefix(arnParts[5], "assumed-role") { + return true, nil + } + return false, nil +} + +func createStaticCredentialsProvider(ctx context.Context, credsProvider aws.CredentialsProvider) (sc credentials.StaticCredentialsProvider, err error) { + creds, err := credsProvider.Retrieve(ctx) + if err != nil { + return sc, err + } + return credentials.StaticCredentialsProvider{Value: creds}, nil +} + +// canProviderBeUsedForLogin returns true if the credentials produced by the provider is known to be usable by the login URL endpoint +func canProviderBeUsedForLogin(credsProvider aws.CredentialsProvider) (bool, error) { + if _, ok := credsProvider.(*vault.AssumeRoleProvider); ok { + return true, nil + } + if _, ok := credsProvider.(*vault.SSORoleCredentialsProvider); ok { + return true, nil + } + if _, ok := credsProvider.(*vault.AssumeRoleWithWebIdentityProvider); ok { + return true, nil + } + if c, ok := credsProvider.(*vault.CachedSessionProvider); ok { + return canProviderBeUsedForLogin(c.SessionProvider) + } + + return false, nil +} + +// Create a signin token +func requestSigninToken(ctx context.Context, creds aws.Credentials, loginURLPrefix string) (string, error) { + jsonSession, err := json.Marshal(map[string]string{ + "sessionId": creds.AccessKeyID, + "sessionKey": creds.SecretAccessKey, + "sessionToken": creds.SessionToken, + }) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, "GET", loginURLPrefix, nil) + if err != nil { + return "", err + } + + q := req.URL.Query() + q.Add("Action", "getSigninToken") + q.Add("Session", string(jsonSession)) + req.URL.RawQuery = q.Encode() + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + log.Printf("Response body was %s", body) + return "", fmt.Errorf("Call to getSigninToken failed with %v", resp.Status) + } + + var respParsed map[string]string + + err = json.Unmarshal(body, &respParsed) + if err != nil { + return "", err + } + + signinToken, ok := respParsed["SigninToken"] + if !ok { + return "", fmt.Errorf("Expected a response with SigninToken") + } + + return signinToken, nil +} diff --git a/vault/assumeroleprovider.go b/vault/assumeroleprovider.go index d482fb33d..87db054ac 100644 --- a/vault/assumeroleprovider.go +++ b/vault/assumeroleprovider.go @@ -26,7 +26,7 @@ type AssumeRoleProvider struct { // Retrieve generates a new set of temporary credentials using STS AssumeRole func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { - role, err := p.assumeRole(ctx) + role, err := p.RetrieveStsCredentials(ctx) if err != nil { return aws.Credentials{}, err } @@ -49,7 +49,7 @@ func (p *AssumeRoleProvider) roleSessionName() string { return p.RoleSessionName } -func (p *AssumeRoleProvider) assumeRole(ctx context.Context) (*ststypes.Credentials, error) { +func (p *AssumeRoleProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { var err error input := &sts.AssumeRoleInput{ diff --git a/vault/assumerolewithwebidentityprovider.go b/vault/assumerolewithwebidentityprovider.go index b5faa1b01..ec2537eb8 100644 --- a/vault/assumerolewithwebidentityprovider.go +++ b/vault/assumerolewithwebidentityprovider.go @@ -25,7 +25,7 @@ type AssumeRoleWithWebIdentityProvider struct { // Retrieve generates a new set of temporary credentials using STS AssumeRoleWithWebIdentity func (p *AssumeRoleWithWebIdentityProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { - creds, err := p.assumeRole(ctx) + creds, err := p.RetrieveStsCredentials(ctx) if err != nil { return aws.Credentials{}, err } @@ -48,7 +48,7 @@ func (p *AssumeRoleWithWebIdentityProvider) roleSessionName() string { return p.RoleSessionName } -func (p *AssumeRoleWithWebIdentityProvider) assumeRole(ctx context.Context) (*ststypes.Credentials, error) { +func (p *AssumeRoleWithWebIdentityProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { var err error webIdentityToken, err := p.webIdentityToken() diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 900197b1b..1a382d6b3 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -9,34 +9,48 @@ import ( ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" ) +type StsSessionProvider interface { + aws.CredentialsProvider + RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) +} + // CachedSessionProvider retrieves cached credentials from the keyring, or if no credentials are cached // retrieves temporary credentials using the CredentialsFunc type CachedSessionProvider struct { SessionKey SessionMetadata - CredentialsFunc func(context.Context) (*ststypes.Credentials, error) + SessionProvider StsSessionProvider Keyring *SessionKeyring ExpiryWindow time.Duration } -// Retrieve returns cached credentials from the keyring, or if no credentials are cached -// generates a new set of temporary credentials using the CredentialsFunc -func (p *CachedSessionProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { +func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { creds, err := p.Keyring.Get(p.SessionKey) if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow { // lookup missed, we need to create a new one. - creds, err = p.CredentialsFunc(ctx) + creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) if err != nil { - return aws.Credentials{}, err + return nil, err } err = p.Keyring.Set(p.SessionKey, creds) if err != nil { - return aws.Credentials{}, err + return nil, err } } else { log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) } + return creds, nil +} + +// Retrieve returns cached credentials from the keyring, or if no credentials are cached +// generates a new set of temporary credentials using the CredentialsFunc +func (p *CachedSessionProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + creds, err := p.RetrieveStsCredentials(ctx) + if err != nil { + return aws.Credentials{}, err + } + return aws.Credentials{ AccessKeyID: aws.ToString(creds.AccessKeyId), SecretAccessKey: aws.ToString(creds.SecretAccessKey), diff --git a/vault/credentialprocessprovider.go b/vault/credentialprocessprovider.go index 9006a7ca4..d9a7d00fe 100644 --- a/vault/credentialprocessprovider.go +++ b/vault/credentialprocessprovider.go @@ -55,7 +55,7 @@ func (p *CredentialProcessProvider) retrieveWith(ctx context.Context, fn func(st }, nil } -func (p *CredentialProcessProvider) callCredentialProcess(ctx context.Context) (*ststypes.Credentials, error) { +func (p *CredentialProcessProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { return p.callCredentialProcessWith(ctx, executeProcess) } diff --git a/vault/sessiontokenprovider.go b/vault/sessiontokenprovider.go index e19634ccf..b58179d0f 100644 --- a/vault/sessiontokenprovider.go +++ b/vault/sessiontokenprovider.go @@ -19,7 +19,7 @@ type SessionTokenProvider struct { // Retrieve generates a new set of temporary credentials using STS GetSessionToken func (p *SessionTokenProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { - creds, err := p.GetSessionToken(ctx) + creds, err := p.RetrieveStsCredentials(ctx) if err != nil { return aws.Credentials{}, err } @@ -34,7 +34,7 @@ func (p *SessionTokenProvider) Retrieve(ctx context.Context) (aws.Credentials, e } // GetSessionToken generates a new set of temporary credentials using STS GetSessionToken -func (p *SessionTokenProvider) GetSessionToken(ctx context.Context) (*ststypes.Credentials, error) { +func (p *SessionTokenProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { var err error input := &sts.GetSessionTokenInput{ diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 0f1c8c8bb..25c008b61 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -93,6 +93,10 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s return resp.RoleCredentials, nil } +func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { + return p.getRoleCredentialsAsStsCredemtials(ctx) +} + // getRoleCredentialsAsStsCredemtials returns getRoleCredentials as sts.Credentials because sessions.Store expects it func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx context.Context) (*ststypes.Credentials, error) { creds, err := p.getRoleCredentials(ctx) diff --git a/vault/vault.go b/vault/vault.go index 5480dad25..9f3d4a6e7 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -64,7 +64,7 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke }, Keyring: &SessionKeyring{Keyring: k}, ExpiryWindow: defaultExpirationWindow, - CredentialsFunc: sessionTokenProvider.GetSessionToken, + SessionProvider: sessionTokenProvider, }, nil } @@ -96,7 +96,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr }, Keyring: &SessionKeyring{Keyring: k}, ExpiryWindow: defaultExpirationWindow, - CredentialsFunc: p.assumeRole, + SessionProvider: p, }, nil } @@ -125,7 +125,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf }, Keyring: &SessionKeyring{Keyring: k}, ExpiryWindow: defaultExpirationWindow, - CredentialsFunc: p.assumeRole, + SessionProvider: p, }, nil } @@ -155,7 +155,7 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use }, Keyring: &SessionKeyring{Keyring: k}, ExpiryWindow: defaultExpirationWindow, - CredentialsFunc: ssoRoleCredentialsProvider.getRoleCredentialsAsStsCredemtials, + SessionProvider: ssoRoleCredentialsProvider, }, nil } @@ -177,27 +177,17 @@ func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useS }, Keyring: &SessionKeyring{Keyring: k}, ExpiryWindow: defaultExpirationWindow, - CredentialsFunc: credentialProcessProvider.callCredentialProcess, + SessionProvider: credentialProcessProvider, }, nil } return credentialProcessProvider, nil } -func NewFederationTokenCredentialsProvider(ctx context.Context, profileName string, k *CredentialKeyring, config *ProfileConfig) (aws.CredentialsProvider, error) { - credentialsName, err := FindMasterCredentialsNameFor(profileName, k, config) - if err != nil { - return nil, err - } - masterCreds := NewMasterCredentialsProvider(k, credentialsName) - - return NewFederationTokenProvider(ctx, masterCreds, config) -} - func NewFederationTokenProvider(ctx context.Context, credsProvider aws.CredentialsProvider, config *ProfileConfig) (*FederationTokenProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints) - currentUsername, err := GetUsernameFromSession(ctx, cfg) + name, err := GetUsernameFromSession(ctx, cfg) if err != nil { return nil, err } @@ -205,7 +195,7 @@ func NewFederationTokenProvider(ctx context.Context, credsProvider aws.Credentia log.Printf("Using GetFederationToken for credentials") return &FederationTokenProvider{ StsClient: sts.NewFromConfig(cfg), - Name: currentUsername, + Name: name, Duration: config.GetFederationTokenDuration, }, nil } @@ -227,17 +217,19 @@ func FindMasterCredentialsNameFor(profileName string, keyring *CredentialKeyring return FindMasterCredentialsNameFor(config.SourceProfileName, keyring, config) } -type tempCredsCreator struct { +type TempCredentialsCreator struct { Keyring *CredentialKeyring // DisableSessions will disable the use of GetSessionToken DisableSessions bool // DisableCache will disable the use of the session cache DisableCache bool + // DisableSessionsForProfile is a profile for which sessions should not be used + DisableSessionsForProfile string chainedMfa string } -func (t *tempCredsCreator) getSourceCreds(config *ProfileConfig, hasStoredCredentials bool) (sourcecredsProvider aws.CredentialsProvider, err error) { +func (t *TempCredentialsCreator) getSourceCreds(config *ProfileConfig, hasStoredCredentials bool) (sourcecredsProvider aws.CredentialsProvider, err error) { if hasStoredCredentials { log.Printf("profile %s: using stored credentials", config.ProfileName) return NewMasterCredentialsProvider(t.Keyring, config.ProfileName), nil @@ -251,7 +243,7 @@ func (t *tempCredsCreator) getSourceCreds(config *ProfileConfig, hasStoredCreden return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) } -func (t *tempCredsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { +func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { hasStoredCredentials, err := t.Keyring.Has(config.ProfileName) if err != nil { return nil, err @@ -300,10 +292,13 @@ func (t *tempCredsCreator) GetProviderForProfile(config *ProfileConfig) (aws.Cre } // canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason -func (t *tempCredsCreator) canUseGetSessionToken(c *ProfileConfig) (bool, string) { +func (t *TempCredentialsCreator) canUseGetSessionToken(c *ProfileConfig) (bool, string) { if t.DisableSessions { return false, "sessions are disabled" } + if t.DisableSessionsForProfile == c.ProfileName { + return false, "sessions are disabled for this profile" + } if c.IsChained() { if !c.ChainedFromProfile.HasMfaSerial() { @@ -338,7 +333,7 @@ func mfaDetails(mfaChained bool, config *ProfileConfig) string { // NewTempCredentialsProvider creates a credential provider for the given config func NewTempCredentialsProvider(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool, disableCache bool) (aws.CredentialsProvider, error) { - t := tempCredsCreator{ + t := TempCredentialsCreator{ Keyring: keyring, DisableSessions: disableSessions, DisableCache: disableCache,