From cf78af3fa3689d151c03eba1c2a78e4a6e900b90 Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Sat, 4 Mar 2023 11:16:31 +1100 Subject: [PATCH] Refactor for clarity --- cli/exec.go | 6 +-- cli/export.go | 4 +- cli/login.go | 4 +- cli/rotate.go | 8 ++-- server/ecsserver.go | 4 +- vault/config.go | 58 +++++++++++------------ vault/config_test.go | 22 ++++----- vault/mfa.go | 2 +- vault/vault.go | 106 +++++++++++++++++++++---------------------- 9 files changed, 107 insertions(+), 107 deletions(-) diff --git a/cli/exec.go b/cli/exec.go index 3a15b3f76..0e354e770 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -29,7 +29,7 @@ type ExecCommandInput struct { StartEcsServer bool Lazy bool JSONDeprecated bool - Config vault.Config + Config vault.ProfileConfig SessionDuration time.Duration NoSession bool UseStdout bool @@ -167,7 +167,7 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke return 0, err } - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName) + config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName) if err != nil { return 0, fmt.Errorf("Error loading config: %w", err) } @@ -260,7 +260,7 @@ func createEnv(profileName string, region string) environ { return env } -func startEcsServerAndSetEnv(credsProvider aws.CredentialsProvider, config *vault.Config, lazy bool, cmdEnv *environ) error { +func startEcsServerAndSetEnv(credsProvider aws.CredentialsProvider, config *vault.ProfileConfig, lazy bool, cmdEnv *environ) error { ecsServer, err := server.NewEcsServer(context.TODO(), credsProvider, config, "", 0, lazy) if err != nil { return err diff --git a/cli/export.go b/cli/export.go index c09409c76..2f31ef445 100644 --- a/cli/export.go +++ b/cli/export.go @@ -19,7 +19,7 @@ import ( type ExportCommandInput struct { ProfileName string Format string - Config vault.Config + Config vault.ProfileConfig SessionDuration time.Duration NoSession bool UseStdout bool @@ -90,7 +90,7 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin return fmt.Errorf("in an existing aws-vault subshell; 'exit' from the subshell or unset AWS_VAULT to force") } - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName) + config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName) if err != nil { return fmt.Errorf("Error loading config: %w", err) } diff --git a/cli/login.go b/cli/login.go index c2bc16bd9..a9006ca6e 100644 --- a/cli/login.go +++ b/cli/login.go @@ -24,7 +24,7 @@ type LoginCommandInput struct { ProfileName string UseStdout bool Path string - Config vault.Config + Config vault.ProfileConfig SessionDuration time.Duration NoSession bool } @@ -81,7 +81,7 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { } func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error { - config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName) + config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName) if err != nil { return fmt.Errorf("Error loading config: %w", err) } diff --git a/cli/rotate.go b/cli/rotate.go index 23752e908..bf2384fca 100644 --- a/cli/rotate.go +++ b/cli/rotate.go @@ -16,7 +16,7 @@ import ( type RotateCommandInput struct { NoSession bool ProfileName string - Config vault.Config + Config vault.ProfileConfig } func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) { @@ -55,7 +55,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin vault.UseSessionCache = false configLoader := vault.NewConfigLoader(input.Config, f, input.ProfileName) - config, err := configLoader.LoadFromProfile(input.ProfileName) + config, err := configLoader.GetProfileConfig(input.ProfileName) if err != nil { return fmt.Errorf("Error loading config: %w", err) } @@ -170,7 +170,7 @@ func retry(maxTime time.Duration, sleep time.Duration, f func() error) (err erro } } -func getUsernameIfAssumingRole(ctx context.Context, awsCfg aws.Config, config *vault.Config) (*string, error) { +func getUsernameIfAssumingRole(ctx context.Context, awsCfg aws.Config, config *vault.ProfileConfig) (*string, error) { if config.RoleARN != "" { n, err := vault.GetUsernameFromSession(ctx, awsCfg) if err != nil { @@ -185,7 +185,7 @@ func getUsernameIfAssumingRole(ctx context.Context, awsCfg aws.Config, config *v func getProfilesInChain(profileName string, configLoader *vault.ConfigLoader) (profileNames []string, err error) { profileNames = append(profileNames, profileName) - config, err := configLoader.LoadFromProfile(profileName) + config, err := configLoader.GetProfileConfig(profileName) if err != nil { return profileNames, err } diff --git a/server/ecsserver.go b/server/ecsserver.go index cb34bfcb9..488d3f767 100644 --- a/server/ecsserver.go +++ b/server/ecsserver.go @@ -63,10 +63,10 @@ type EcsServer struct { server http.Server cache sync.Map baseCredsProvider aws.CredentialsProvider - config *vault.Config + config *vault.ProfileConfig } -func NewEcsServer(ctx context.Context, baseCredsProvider aws.CredentialsProvider, config *vault.Config, authToken string, port int, lazyLoadBaseCreds bool) (*EcsServer, error) { +func NewEcsServer(ctx context.Context, baseCredsProvider aws.CredentialsProvider, config *vault.ProfileConfig, authToken string, port int, lazyLoadBaseCreds bool) (*EcsServer, error) { listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { return nil, err diff --git a/vault/config.go b/vault/config.go index 37b54064e..bc8c6eca4 100644 --- a/vault/config.go +++ b/vault/config.go @@ -269,13 +269,14 @@ func (c *ConfigFile) ProfileNames() []string { // ConfigLoader loads config from configfile and environment variables type ConfigLoader struct { - BaseConfig Config - File *ConfigFile - ActiveProfile string + BaseConfig ProfileConfig + File *ConfigFile + ActiveProfile string + visitedProfiles []string } -func NewConfigLoader(baseConfig Config, file *ConfigFile, activeProfile string) *ConfigLoader { +func NewConfigLoader(baseConfig ProfileConfig, file *ConfigFile, activeProfile string) *ConfigLoader { return &ConfigLoader{ BaseConfig: baseConfig, File: file, @@ -297,7 +298,7 @@ func (cl *ConfigLoader) resetLoopDetection() { cl.visitedProfiles = []string{} } -func (cl *ConfigLoader) populateFromDefaults(config *Config) { +func (cl *ConfigLoader) populateFromDefaults(config *ProfileConfig) { if config.AssumeRoleDuration == 0 { config.AssumeRoleDuration = DefaultSessionDuration } @@ -312,7 +313,7 @@ func (cl *ConfigLoader) populateFromDefaults(config *Config) { } } -func (cl *ConfigLoader) populateFromConfigFile(config *Config, profileName string) error { +func (cl *ConfigLoader) populateFromConfigFile(config *ProfileConfig, profileName string) error { if !cl.visitProfile(profileName) { return fmt.Errorf("Loop detected in config file for profile '%s'", profileName) } @@ -419,7 +420,7 @@ func (cl *ConfigLoader) populateFromConfigFile(config *Config, profileName strin return nil } -func (cl *ConfigLoader) populateFromEnv(profile *Config) { +func (cl *ConfigLoader) populateFromEnv(profile *ProfileConfig) { if region := os.Getenv("AWS_REGION"); region != "" && profile.Region == "" { log.Printf("Using region %q from AWS_REGION", region) profile.Region = region @@ -501,9 +502,9 @@ func (cl *ConfigLoader) populateFromEnv(profile *Config) { } } -func (cl *ConfigLoader) hydrateSourceConfig(config *Config) error { +func (cl *ConfigLoader) hydrateSourceConfig(config *ProfileConfig) error { if config.SourceProfileName != "" { - sc, err := cl.LoadFromProfile(config.SourceProfileName) + sc, err := cl.GetProfileConfig(config.SourceProfileName) if err != nil { return err } @@ -513,8 +514,8 @@ func (cl *ConfigLoader) hydrateSourceConfig(config *Config) error { return nil } -// LoadFromProfile loads the profile from the config file and environment variables into config -func (cl *ConfigLoader) LoadFromProfile(profileName string) (*Config, error) { +// GetProfileConfig loads the profile from the config file and environment variables into config +func (cl *ConfigLoader) GetProfileConfig(profileName string) (*ProfileConfig, error) { config := cl.BaseConfig config.ProfileName = profileName cl.populateFromEnv(&config) @@ -535,8 +536,8 @@ func (cl *ConfigLoader) LoadFromProfile(profileName string) (*Config, error) { return &config, nil } -// Config is a collection of configuration options for creating temporary credentials -type Config struct { +// ProfileConfig is a collection of configuration options for creating temporary credentials +type ProfileConfig struct { // ProfileName specifies the name of the profile config ProfileName string @@ -544,10 +545,10 @@ type Config struct { SourceProfileName string // SourceProfile is the profile where credentials come from - SourceProfile *Config + SourceProfile *ProfileConfig - // ChainedFromProfile is the profile that used this profile as it's source profile - ChainedFromProfile *Config + // ChainedFromProfile is the profile that used this profile as its source profile + ChainedFromProfile *ProfileConfig // Region is the AWS region Region string @@ -619,7 +620,7 @@ type Config struct { } // SetSessionTags parses a comma separated key=vaue string and sets Config.SessionTags map -func (c *Config) SetSessionTags(s string) error { +func (c *ProfileConfig) SetSessionTags(s string) error { c.SessionTags = make(map[string]string) for _, tag := range strings.Split(s, ",") { kvPair := strings.SplitN(tag, "=", 2) @@ -633,7 +634,7 @@ func (c *Config) SetSessionTags(s string) error { } // SetTransitiveSessionTags parses a comma separated string and sets Config.TransitiveSessionTags -func (c *Config) SetTransitiveSessionTags(s string) { +func (c *ProfileConfig) SetTransitiveSessionTags(s string) { for _, tag := range strings.Split(s, ",") { if tag = strings.TrimSpace(tag); tag != "" { c.TransitiveSessionTags = append(c.TransitiveSessionTags, tag) @@ -641,46 +642,46 @@ func (c *Config) SetTransitiveSessionTags(s string) { } } -func (c *Config) IsChained() bool { +func (c *ProfileConfig) IsChained() bool { return c.ChainedFromProfile != nil } -func (c *Config) HasSourceProfile() bool { +func (c *ProfileConfig) HasSourceProfile() bool { return c.SourceProfile != nil } -func (c *Config) HasMfaSerial() bool { +func (c *ProfileConfig) HasMfaSerial() bool { return c.MfaSerial != "" } -func (c *Config) HasRole() bool { +func (c *ProfileConfig) HasRole() bool { return c.RoleARN != "" } -func (c *Config) HasSSOSession() bool { +func (c *ProfileConfig) HasSSOSession() bool { return c.SSOSession != "" } -func (c *Config) HasSSOStartURL() bool { +func (c *ProfileConfig) HasSSOStartURL() bool { return c.SSOStartURL != "" } -func (c *Config) HasWebIdentity() bool { +func (c *ProfileConfig) HasWebIdentity() bool { return c.WebIdentityTokenFile != "" || c.WebIdentityTokenProcess != "" } -func (c *Config) HasCredentialProcess() bool { +func (c *ProfileConfig) HasCredentialProcess() bool { return c.CredentialProcess != "" } -func (c *Config) GetSessionTokenDuration() time.Duration { +func (c *ProfileConfig) GetSessionTokenDuration() time.Duration { if c.IsChained() { return c.ChainedGetSessionTokenDuration } return c.NonChainedGetSessionTokenDuration } -func (c *Config) Validate() error { +func (c *ProfileConfig) Validate() error { if c.HasSSOSession() && !c.HasSSOStartURL() { return fmt.Errorf("profile '%s' has sso_session but no sso_start_url", c.ProfileName) } @@ -700,7 +701,6 @@ func (c *Config) Validate() error { } else if c.HasRole() { n++ } - if n > 1 { return fmt.Errorf("profile '%s' has more than one source of credentials", c.ProfileName) } diff --git a/vault/config_test.go b/vault/config_test.go index 0dd05ba35..3e3644ab0 100644 --- a/vault/config_test.go +++ b/vault/config_test.go @@ -254,7 +254,7 @@ func TestIncludeProfile(t *testing.T) { } configLoader := &vault.ConfigLoader{File: configFile} - config, err := configLoader.LoadFromProfile("testincludeprofile2") + config, err := configLoader.GetProfileConfig("testincludeprofile2") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -274,7 +274,7 @@ func TestIncludeSsoSession(t *testing.T) { } configLoader := &vault.ConfigLoader{File: configFile} - config, err := configLoader.LoadFromProfile("with-sso-session") + config, err := configLoader.GetProfileConfig("with-sso-session") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -369,7 +369,7 @@ source_profile=foo } configLoader := &vault.ConfigLoader{File: configFile} - config, err := configLoader.LoadFromProfile("foo") + config, err := configLoader.GetProfileConfig("foo") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -406,7 +406,7 @@ source_profile=root } configLoader := &vault.ConfigLoader{File: configFile} - config, err := configLoader.LoadFromProfile("foo") + config, err := configLoader.GetProfileConfig("foo") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -441,7 +441,7 @@ func TestSetSessionTags(t *testing.T) { } for _, tc := range testCases { - config := vault.Config{} + config := vault.ProfileConfig{} err := config.SetSessionTags(tc.stringValue) if tc.ok { if err != nil { @@ -473,7 +473,7 @@ func TestSetTransitiveSessionTags(t *testing.T) { } for _, tc := range testCases { - config := vault.Config{} + config := vault.ProfileConfig{} config.SetTransitiveSessionTags(tc.stringValue) if !reflect.DeepEqual(tc.expected, config.TransitiveSessionTags) { t.Fatalf("Expected TransitiveSessionTags: %+v, got %+v", tc.expected, config.TransitiveSessionTags) @@ -496,7 +496,7 @@ transitive_session_tags = tagOne ,tagTwo,tagThree t.Fatal(err) } configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "tagged"} - config, err := configLoader.LoadFromProfile("tagged") + config, err := configLoader.GetProfileConfig("tagged") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -533,7 +533,7 @@ transitive_session_tags = tagOne ,tagTwo,tagThree t.Fatal(err) } configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "tagged"} - config, err := configLoader.LoadFromProfile("tagged") + config, err := configLoader.GetProfileConfig("tagged") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -578,7 +578,7 @@ source_profile = interim t.Fatal(err) } configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"} - config, err := configLoader.LoadFromProfile("target") + config, err := configLoader.GetProfileConfig("target") if err != nil { t.Fatalf("Should have found a profile: %v", err) } @@ -640,13 +640,13 @@ credential_process = true configFile, _ := vault.LoadConfig(f) configLoader := &vault.ConfigLoader{File: configFile} - config, _ := configLoader.LoadFromProfile("foo:staging") + config, _ := configLoader.GetProfileConfig("foo:staging") err := config.Validate() if err != nil { t.Fatalf("Should have validated: %v", err) } - config, _ = configLoader.LoadFromProfile("foo:production") + config, _ = configLoader.GetProfileConfig("foo:production") err = config.Validate() if err == nil { t.Fatalf("Should have failed validation: %v", err) diff --git a/vault/mfa.go b/vault/mfa.go index c5b897edf..1b1313840 100644 --- a/vault/mfa.go +++ b/vault/mfa.go @@ -33,7 +33,7 @@ func (m *Mfa) GetMfaSerial() string { return m.mfaSerial } -func NewMfa(config *Config) *Mfa { +func NewMfa(config *ProfileConfig) *Mfa { m := Mfa{ mfaSerial: config.MfaSerial, } diff --git a/vault/vault.go b/vault/vault.go index 17ddad135..1f9ff7722 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -48,7 +48,7 @@ func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) return &KeyringProvider{k, credentialsName} } -func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *Config) (aws.CredentialsProvider, error) { +func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints) sessionTokenProvider := &SessionTokenProvider{ @@ -74,7 +74,7 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke } // NewAssumeRoleProvider returns a provider that generates credentials using AssumeRole -func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *Config) (aws.CredentialsProvider, error) { +func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints) p := &AssumeRoleProvider{ @@ -107,7 +107,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr // NewAssumeRoleWithWebIdentityProvider returns a provider that generates // credentials using AssumeRoleWithWebIdentity -func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *Config) (aws.CredentialsProvider, error) { +func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) { cfg := NewAwsConfig(config.Region, config.STSRegionalEndpoints) p := &AssumeRoleWithWebIdentityProvider{ @@ -135,7 +135,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *Config) (aw } // NewSSORoleCredentialsProvider creates a provider for SSO credentials -func NewSSORoleCredentialsProvider(k keyring.Keyring, config *Config) (aws.CredentialsProvider, error) { +func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) { cfg := NewAwsConfig(config.SSORegion, config.STSRegionalEndpoints) ssoRoleCredentialsProvider := &SSORoleCredentialsProvider{ @@ -166,7 +166,7 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *Config) (aws.Crede // NewCredentialProcessProvider creates a provider to retrieve credentials from an external // executable as described in https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes -func NewCredentialProcessProvider(k keyring.Keyring, config *Config) (aws.CredentialsProvider, error) { +func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig) (aws.CredentialsProvider, error) { credentialProcessProvider := &CredentialProcessProvider{ CredentialProcess: config.CredentialProcess, } @@ -186,6 +186,49 @@ func NewCredentialProcessProvider(k keyring.Keyring, config *Config) (aws.Creden return credentialProcessProvider, nil } +func NewFederationTokenCredentialsProvider(ctx context.Context, profileName string, k *CredentialKeyring, config *ProfileConfig) (aws.CredentialsProvider, error) { + credentialsName, err := FindMasterCredentialsNameFor(profileName, k, config) + if err != nil { + return nil, err + } + masterCreds := NewMasterCredentialsProvider(k, credentialsName) + + return NewFederationTokenProvider(ctx, masterCreds, config) +} + +func NewFederationTokenProvider(ctx context.Context, credsProvider aws.CredentialsProvider, config *ProfileConfig) (*FederationTokenProvider, error) { + cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints) + + currentUsername, err := GetUsernameFromSession(ctx, cfg) + if err != nil { + return nil, err + } + + log.Printf("Using GetFederationToken for credentials") + return &FederationTokenProvider{ + StsClient: sts.NewFromConfig(cfg), + Name: currentUsername, + Duration: config.GetFederationTokenDuration, + }, nil +} + +func FindMasterCredentialsNameFor(profileName string, keyring *CredentialKeyring, config *ProfileConfig) (string, error) { + hasMasterCreds, err := keyring.Has(profileName) + if err != nil { + return "", err + } + + if hasMasterCreds { + return profileName, nil + } + + if profileName == config.SourceProfileName { + return "", fmt.Errorf("No master credentials found") + } + + return FindMasterCredentialsNameFor(config.SourceProfileName, keyring, config) +} + type tempCredsCreator struct { // UseSession will disable the use of GetSessionToken when set to false UseSession bool @@ -194,7 +237,7 @@ type tempCredsCreator struct { chainedMfa string } -func (t *tempCredsCreator) getSourceCreds(config *Config) (sourcecredsProvider aws.CredentialsProvider, err error) { +func (t *tempCredsCreator) getSourceCreds(config *ProfileConfig) (sourcecredsProvider aws.CredentialsProvider, err error) { if config.HasSourceProfile() { log.Printf("profile %s: sourcing credentials from profile %s", config.ProfileName, config.SourceProfile.ProfileName) return t.GetProviderForProfile(config.SourceProfile) @@ -213,7 +256,7 @@ func (t *tempCredsCreator) getSourceCreds(config *Config) (sourcecredsProvider a return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) } -func (t *tempCredsCreator) GetProviderForProfile(config *Config) (aws.CredentialsProvider, error) { +func (t *tempCredsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { if err := config.Validate(); err != nil { return nil, err } @@ -259,7 +302,7 @@ func (t *tempCredsCreator) GetProviderForProfile(config *Config) (aws.Credential } // canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason -func (t *tempCredsCreator) canUseGetSessionToken(c *Config) (bool, string) { +func (t *tempCredsCreator) canUseGetSessionToken(c *ProfileConfig) (bool, string) { if !t.UseSession { return false, "sessions are disabled" } @@ -285,7 +328,7 @@ func (t *tempCredsCreator) canUseGetSessionToken(c *Config) (bool, string) { return true, "" } -func mfaDetails(mfaChained bool, config *Config) string { +func mfaDetails(mfaChained bool, config *ProfileConfig) string { if mfaChained { return "(chained MFA)" } @@ -296,53 +339,10 @@ func mfaDetails(mfaChained bool, config *Config) string { } // NewTempCredentialsProvider creates a credential provider for the given config -func NewTempCredentialsProvider(config *Config, keyring *CredentialKeyring, useSession bool) (aws.CredentialsProvider, error) { +func NewTempCredentialsProvider(config *ProfileConfig, keyring *CredentialKeyring, useSession bool) (aws.CredentialsProvider, error) { t := tempCredsCreator{ Keyring: keyring, UseSession: useSession, } return t.GetProviderForProfile(config) } - -func NewFederationTokenCredentialsProvider(ctx context.Context, profileName string, k *CredentialKeyring, config *Config) (aws.CredentialsProvider, error) { - credentialsName, err := FindMasterCredentialsNameFor(profileName, k, config) - if err != nil { - return nil, err - } - masterCreds := NewMasterCredentialsProvider(k, credentialsName) - - return NewFederationTokenProvider(ctx, masterCreds, config) -} - -func NewFederationTokenProvider(ctx context.Context, credsProvider aws.CredentialsProvider, config *Config) (*FederationTokenProvider, error) { - cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints) - - currentUsername, err := GetUsernameFromSession(ctx, cfg) - if err != nil { - return nil, err - } - - log.Printf("Using GetFederationToken for credentials") - return &FederationTokenProvider{ - StsClient: sts.NewFromConfig(cfg), - Name: currentUsername, - Duration: config.GetFederationTokenDuration, - }, nil -} - -func FindMasterCredentialsNameFor(profileName string, keyring *CredentialKeyring, config *Config) (string, error) { - hasMasterCreds, err := keyring.Has(profileName) - if err != nil { - return "", err - } - - if hasMasterCreds { - return profileName, nil - } - - if profileName == config.SourceProfileName { - return "", fmt.Errorf("No master credentials found") - } - - return FindMasterCredentialsNameFor(config.SourceProfileName, keyring, config) -}