Skip to content

Commit

Permalink
Merge pull request #662 from 99designs/sts_regional_endpoints
Browse files Browse the repository at this point in the history
Support sts_regional_endpoints configuration
  • Loading branch information
mtibben committed Sep 19, 2020
2 parents 0b9e2d4 + 6f3764d commit 1afe60d
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 13 deletions.
1 change: 1 addition & 0 deletions USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ To configure the default flag values of `aws-vault` and its subcommands:
To override the AWS config file (used in the `exec`, `login` and `rotate` subcommands):
* `AWS_REGION`: The AWS region
* `AWS_DEFAULT_REGION`: The AWS region, applied only if `AWS_REGION` isn't set
* `AWS_STS_REGIONAL_ENDPOINTS`: STS endpoint resolution logic, must be "regional" or "legacy"
* `AWS_MFA_SERIAL`: The identification number of the MFA device to use
* `AWS_ROLE_ARN`: Specifies the ARN of an IAM role in the active profile
* `AWS_ROLE_SESSION_NAME`: Specifies the name to attach to the role session in the active profile
Expand Down
2 changes: 1 addition & 1 deletion cli/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin
}
}

sess, err := vault.NewSessionWithCreds(sessCreds, config.Region)
sess, err := vault.NewSessionWithCreds(sessCreds, config.Region, config.STSRegionalEndpoints)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions vault/assumeroleprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func (p *AssumeRoleProvider) assumeRole() (*sts.Credentials, error) {
}
}

log.Printf("Using STS endpoint %s", p.StsClient.Endpoint)

resp, err := p.StsClient.AssumeRole(input)
if err != nil {
return nil, err
Expand Down
12 changes: 12 additions & 0 deletions vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ type ProfileSection struct {
SSORoleName string `ini:"sso_role_name,omitempty"`
WebIdentityTokenFile string `ini:"web_identity_token_file,omitempty"`
WebIdentityTokenProcess string `ini:"web_identity_token_process,omitempty"`
STSRegionalEndpoints string `ini:"sts_regional_endpoints,omitempty"`
}

func (s ProfileSection) IsEmpty() bool {
Expand Down Expand Up @@ -322,6 +323,9 @@ func (cl *ConfigLoader) populateFromConfigFile(config *Config, profileName strin
if config.WebIdentityTokenProcess == "" {
config.WebIdentityTokenProcess = psection.WebIdentityTokenProcess
}
if config.STSRegionalEndpoints == "" {
config.STSRegionalEndpoints = psection.STSRegionalEndpoints
}

if psection.ParentProfile != "" {
fmt.Fprint(os.Stderr, "Warning: parent_profile is deprecated, please use include_profile instead in your AWS config\n")
Expand Down Expand Up @@ -363,6 +367,11 @@ func (cl *ConfigLoader) populateFromEnv(profile *Config) {
profile.Region = region
}

if stsRegionalEndpoints := os.Getenv("AWS_STS_REGIONAL_ENDPOINTS"); stsRegionalEndpoints != "" && profile.STSRegionalEndpoints == "" {
log.Printf("Using %q from AWS_STS_REGIONAL_ENDPOINTS", stsRegionalEndpoints)
profile.STSRegionalEndpoints = stsRegionalEndpoints
}

if mfaSerial := os.Getenv("AWS_MFA_SERIAL"); mfaSerial != "" && profile.MfaSerial == "" {
log.Printf("Using mfa_serial %q from AWS_MFA_SERIAL", mfaSerial)
profile.MfaSerial = mfaSerial
Expand Down Expand Up @@ -462,6 +471,9 @@ type Config struct {
// Region is the AWS region
Region string

// STSRegionalEndpoints sets STS endpoint resolution logic, must be "regional" or "legacy"
STSRegionalEndpoints string

// Mfa config
MfaSerial string
MfaToken string
Expand Down
7 changes: 4 additions & 3 deletions vault/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Role_Arn=arn:aws:iam::4451234513441615400570:role/aws_admin
mfa_Serial=arn:aws:iam::1234513441:mfa/blah
Region=us-east-1
duration_seconds=1200
sts_regional_endpoints=legacy
[profile testincludeprofile1]
region=us-east-1
Expand Down Expand Up @@ -103,7 +104,7 @@ func TestConfigParsingProfiles(t *testing.T) {
}{
{vault.ProfileSection{Name: "user2", Region: "us-east-1"}, true},
{vault.ProfileSection{Name: "withsource", SourceProfile: "user2", Region: "us-east-1"}, true},
{vault.ProfileSection{Name: "withMFA", MfaSerial: "arn:aws:iam::1234513441:mfa/blah", RoleARN: "arn:aws:iam::4451234513441615400570:role/aws_admin", Region: "us-east-1", DurationSeconds: 1200, SourceProfile: "user2"}, true},
{vault.ProfileSection{Name: "withMFA", MfaSerial: "arn:aws:iam::1234513441:mfa/blah", RoleARN: "arn:aws:iam::4451234513441615400570:role/aws_admin", Region: "us-east-1", DurationSeconds: 1200, SourceProfile: "user2", STSRegionalEndpoints: "legacy"}, true},
{vault.ProfileSection{Name: "nopenotthere"}, false},
}

Expand Down Expand Up @@ -157,7 +158,7 @@ func TestProfilesFromConfig(t *testing.T) {
{Name: "default", Region: "us-west-2"},
{Name: "user2", Region: "us-east-1"},
{Name: "withsource", Region: "us-east-1", SourceProfile: "user2"},
{Name: "withMFA", MfaSerial: "arn:aws:iam::1234513441:mfa/blah", RoleARN: "arn:aws:iam::4451234513441615400570:role/aws_admin", Region: "us-east-1", DurationSeconds: 1200, SourceProfile: "user2"},
{Name: "withMFA", MfaSerial: "arn:aws:iam::1234513441:mfa/blah", RoleARN: "arn:aws:iam::4451234513441615400570:role/aws_admin", Region: "us-east-1", DurationSeconds: 1200, SourceProfile: "user2", STSRegionalEndpoints: "legacy"},
{Name: "testincludeprofile1", Region: "us-east-1"},
{Name: "testincludeprofile2", IncludeProfile: "testincludeprofile1"},
}
Expand Down Expand Up @@ -191,7 +192,7 @@ func TestAddProfileToExistingConfig(t *testing.T) {
{Name: "default", Region: "us-west-2"},
{Name: "user2", Region: "us-east-1"},
{Name: "withsource", Region: "us-east-1", SourceProfile: "user2"},
{Name: "withMFA", MfaSerial: "arn:aws:iam::1234513441:mfa/blah", RoleARN: "arn:aws:iam::4451234513441615400570:role/aws_admin", Region: "us-east-1", DurationSeconds: 1200, SourceProfile: "user2"},
{Name: "withMFA", MfaSerial: "arn:aws:iam::1234513441:mfa/blah", RoleARN: "arn:aws:iam::4451234513441615400570:role/aws_admin", Region: "us-east-1", DurationSeconds: 1200, SourceProfile: "user2", STSRegionalEndpoints: "legacy"},
{Name: "testincludeprofile1", Region: "us-east-1"},
{Name: "testincludeprofile2", IncludeProfile: "testincludeprofile1"},
{Name: "llamas", MfaSerial: "testserial", Region: "us-east-1", SourceProfile: "default"},
Expand Down
2 changes: 2 additions & 0 deletions vault/sessiontokenprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func (p *SessionTokenProvider) GetSessionToken() (*sts.Credentials, error) {
}
}

log.Printf("Using STS endpoint %s", p.StsClient.Endpoint)

resp, err := p.StsClient.GetSessionToken(input)
if err != nil {
return nil, err
Expand Down
28 changes: 19 additions & 9 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/99designs/keyring"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sso"
"github.com/aws/aws-sdk-go/service/ssooidc"
Expand All @@ -27,15 +28,23 @@ func init() {

var UseSessionCache = true

func NewSession(region string) (*session.Session, error) {
func NewSession(region, stsRegionalEndpoints string) (*session.Session, error) {
endpointConfig, err := endpoints.GetSTSRegionalEndpoint(stsRegionalEndpoints)
if err != nil && stsRegionalEndpoints != "" {
return nil, err
}

return session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(region)},
Config: aws.Config{
Region: aws.String(region),
STSRegionalEndpoint: endpointConfig,
},
SharedConfigState: session.SharedConfigDisable,
})
}

func NewSessionWithCreds(creds *credentials.Credentials, region string) (*session.Session, error) {
s, err := NewSession(region)
func NewSessionWithCreds(creds *credentials.Credentials, region, stsRegionalEndpoints string) (*session.Session, error) {
s, err := NewSession(region, stsRegionalEndpoints)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -80,7 +89,7 @@ func NewMasterCredentials(k *CredentialKeyring, credentialsName string) *credent
}

func NewSessionTokenProvider(creds *credentials.Credentials, k keyring.Keyring, config *Config) (credentials.Provider, error) {
sess, err := NewSessionWithCreds(creds, config.Region)
sess, err := NewSessionWithCreds(creds, config.Region, config.STSRegionalEndpoints)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,7 +123,7 @@ func NewSessionTokenProvider(creds *credentials.Credentials, k keyring.Keyring,

// NewAssumeRoleProvider returns a provider that generates credentials using AssumeRole
func NewAssumeRoleProvider(creds *credentials.Credentials, k keyring.Keyring, config *Config) (credentials.Provider, error) {
sess, err := NewSessionWithCreds(creds, config.Region)
sess, err := NewSessionWithCreds(creds, config.Region, config.STSRegionalEndpoints)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -152,7 +161,7 @@ func NewAssumeRoleProvider(creds *credentials.Credentials, k keyring.Keyring, co
// NewAssumeRoleWithWebIdentityProvider returns a provider that generates
// credentials using AssumeRoleWithWebIdentity
func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *Config) (credentials.Provider, error) {
sess, err := NewSession(config.Region)
sess, err := NewSession(config.Region, config.STSRegionalEndpoints)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -184,7 +193,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *Config) (cr

// NewSSORoleCredentialsProvider creates a provider for SSO credentials
func NewSSORoleCredentialsProvider(k keyring.Keyring, config *Config) (credentials.Provider, error) {
sess, err := NewSession(config.SSORegion)
sess, err := NewSession(config.SSORegion, config.STSRegionalEndpoints)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -305,7 +314,8 @@ func NewFederationTokenCredentials(profileName string, k *CredentialKeyring, con
return nil, err
}

sess, err := NewSessionWithCreds(NewMasterCredentials(k, credentialsName), config.Region)
masterCreds := NewMasterCredentials(k, credentialsName)
sess, err := NewSessionWithCreds(masterCreds, config.Region, config.STSRegionalEndpoints)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 1afe60d

Please sign in to comment.