Skip to content

Commit

Permalink
Merge pull request #383 from jstewmon/thread-mfa-serial
Browse files Browse the repository at this point in the history
thread any found MFASerial value through recurive Retrieve calls
  • Loading branch information
mtibben committed Jun 23, 2019
2 parents 4a34d94 + f09d629 commit 935affd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
11 changes: 4 additions & 7 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func LoginCommand(app *kingpin.Application, input LoginCommandInput) {
noSession = true
}

provider, err := vault.NewVaultProvider(input.Keyring, input.Profile, vault.VaultOptions{
creds, err := vault.NewVaultCredentials(input.Keyring, input.Profile, vault.VaultOptions{
AssumeRoleDuration: input.AssumeRoleDuration,
MfaToken: input.MfaToken,
MfaPrompt: input.MfaPrompt,
Expand All @@ -102,11 +102,8 @@ func LoginCommand(app *kingpin.Application, input LoginCommandInput) {
Region: profile.Region,
})
if err != nil {
app.Fatalf("Failed to create vault provider: %v", err)
return
app.Fatalf("%v", err)
}

creds := credentials.NewCredentials(provider)
val, err := creds.Get()
if err != nil {
app.Fatalf(awsConfig.FormatCredentialError(err, input.Profile))
Expand All @@ -118,7 +115,7 @@ func LoginCommand(app *kingpin.Application, input LoginCommandInput) {
// if AssumeRole isn't used, GetFederationToken has to be used for IAM credentials
if val.SessionToken == "" {
log.Printf("No session token found, calling GetFederationToken")
stsCreds, err := getFederationToken(val, input.FederationTokenDuration, provider.Region)
stsCreds, err := getFederationToken(val, input.FederationTokenDuration, profile.Region)
if err != nil {
app.Fatalf("Failed to call GetFederationToken: %v\n"+
"Login for non-assumed roles depends on permission to call sts:GetFederationToken", err)
Expand All @@ -141,7 +138,7 @@ func LoginCommand(app *kingpin.Application, input LoginCommandInput) {
return
}

loginURLPrefix, destination := generateLoginURL(provider.Region, input.Path)
loginURLPrefix, destination := generateLoginURL(profile.Region, input.Path)

req, err := http.NewRequest("GET", loginURLPrefix, nil)
if err != nil {
Expand Down
43 changes: 37 additions & 6 deletions vault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type VaultOptions struct {
SessionDuration time.Duration
AssumeRoleDuration time.Duration
ExpiryWindow time.Duration
MfaSerial string
MfaToken string
MfaPrompt prompt.PromptFunc
NoSession bool
Expand Down Expand Up @@ -266,10 +267,10 @@ func (p *VaultProvider) getSessionToken(creds *credentials.Value) (sts.Credentia
DurationSeconds: aws.Int64(int64(p.SessionDuration.Seconds())),
}

if profile, _ := p.Config.Profile(p.profile); profile.MFASerial != "" {
params.SerialNumber = aws.String(profile.MFASerial)
if p.VaultOptions.MfaSerial != "" {
params.SerialNumber = aws.String(p.VaultOptions.MfaSerial)
if p.MfaToken == "" {
token, err := p.MfaPrompt(fmt.Sprintf("Enter token for %s: ", profile.MFASerial))
token, err := p.MfaPrompt(fmt.Sprintf("Enter token for %s: ", p.VaultOptions.MfaSerial))
if err != nil {
return sts.Credentials{}, err
}
Expand Down Expand Up @@ -349,10 +350,10 @@ func (p *VaultProvider) assumeRole(creds credentials.Value, profile Profile) (st
}

// if we don't have a session, we need to include MFA token in the AssumeRole call
if profile.MFASerial != "" {
input.SerialNumber = aws.String(profile.MFASerial)
if p.VaultOptions.MfaSerial != "" {
input.SerialNumber = aws.String(p.VaultOptions.MfaSerial)
if p.MfaToken == "" {
token, err := p.MfaPrompt(fmt.Sprintf("Enter token for %s: ", profile.MFASerial))
token, err := p.MfaPrompt(fmt.Sprintf("Enter token for %s: ", p.VaultOptions.MfaSerial))
if err != nil {
return sts.Credentials{}, err
}
Expand Down Expand Up @@ -420,6 +421,14 @@ type VaultCredentials struct {
}

func NewVaultCredentials(k keyring.Keyring, profile string, opts VaultOptions) (*VaultCredentials, error) {
// always get the list of profiles for cycle detection
profiles, err := profileChain(profile, opts.Config)
if err != nil {
return nil, err
}
if len(opts.MfaSerial) == 0 {
opts.MfaSerial = findMfaSerial(profiles)
}
provider, err := NewVaultProvider(k, profile, opts)
if err != nil {
return nil, err
Expand All @@ -431,3 +440,25 @@ func NewVaultCredentials(k keyring.Keyring, profile string, opts VaultOptions) (
func (v *VaultCredentials) Expires() time.Time {
return v.provider.expires
}

func findMfaSerial(profiles []Profile) string {
for _, profile := range profiles {
if len(profile.MFASerial) > 0 {
return profile.MFASerial
}
}
return ""
}

func profileChain(profile string, config *Config) ([]Profile, error) {
visited := map[string]bool{}
var profiles []Profile
for configProfile, exists := config.Profile(profile); exists; configProfile, exists = config.Profile(configProfile.SourceProfile) {
if _, ok := visited[configProfile.Name]; ok {
return nil, fmt.Errorf("source profile cycle detected for profile %v", profile)
}
visited[configProfile.Name] = true
profiles = append(profiles, configProfile)
}
return profiles, nil
}

0 comments on commit 935affd

Please sign in to comment.