Skip to content

Commit

Permalink
Refactor for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
mtibben committed Mar 4, 2023
1 parent f2527e2 commit cf78af3
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 107 deletions.
6 changes: 3 additions & 3 deletions cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type ExecCommandInput struct {
StartEcsServer bool
Lazy bool
JSONDeprecated bool
Config vault.Config
Config vault.ProfileConfig
SessionDuration time.Duration
NoSession bool
UseStdout bool
Expand Down Expand Up @@ -167,7 +167,7 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke
return 0, err
}

config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
if err != nil {
return 0, fmt.Errorf("Error loading config: %w", err)
}
Expand Down Expand Up @@ -260,7 +260,7 @@ func createEnv(profileName string, region string) environ {
return env
}

func startEcsServerAndSetEnv(credsProvider aws.CredentialsProvider, config *vault.Config, lazy bool, cmdEnv *environ) error {
func startEcsServerAndSetEnv(credsProvider aws.CredentialsProvider, config *vault.ProfileConfig, lazy bool, cmdEnv *environ) error {
ecsServer, err := server.NewEcsServer(context.TODO(), credsProvider, config, "", 0, lazy)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions cli/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
type ExportCommandInput struct {
ProfileName string
Format string
Config vault.Config
Config vault.ProfileConfig
SessionDuration time.Duration
NoSession bool
UseStdout bool
Expand Down Expand Up @@ -90,7 +90,7 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin
return fmt.Errorf("in an existing aws-vault subshell; 'exit' from the subshell or unset AWS_VAULT to force")
}

config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type LoginCommandInput struct {
ProfileName string
UseStdout bool
Path string
Config vault.Config
Config vault.ProfileConfig
SessionDuration time.Duration
NoSession bool
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
}

func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions cli/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
type RotateCommandInput struct {
NoSession bool
ProfileName string
Config vault.Config
Config vault.ProfileConfig
}

func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) {
Expand Down Expand Up @@ -55,7 +55,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin
vault.UseSessionCache = false

configLoader := vault.NewConfigLoader(input.Config, f, input.ProfileName)
config, err := configLoader.LoadFromProfile(input.ProfileName)
config, err := configLoader.GetProfileConfig(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}
Expand Down Expand Up @@ -170,7 +170,7 @@ func retry(maxTime time.Duration, sleep time.Duration, f func() error) (err erro
}
}

func getUsernameIfAssumingRole(ctx context.Context, awsCfg aws.Config, config *vault.Config) (*string, error) {
func getUsernameIfAssumingRole(ctx context.Context, awsCfg aws.Config, config *vault.ProfileConfig) (*string, error) {
if config.RoleARN != "" {
n, err := vault.GetUsernameFromSession(ctx, awsCfg)
if err != nil {
Expand All @@ -185,7 +185,7 @@ func getUsernameIfAssumingRole(ctx context.Context, awsCfg aws.Config, config *v
func getProfilesInChain(profileName string, configLoader *vault.ConfigLoader) (profileNames []string, err error) {
profileNames = append(profileNames, profileName)

config, err := configLoader.LoadFromProfile(profileName)
config, err := configLoader.GetProfileConfig(profileName)
if err != nil {
return profileNames, err
}
Expand Down
4 changes: 2 additions & 2 deletions server/ecsserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ type EcsServer struct {
server http.Server
cache sync.Map
baseCredsProvider aws.CredentialsProvider
config *vault.Config
config *vault.ProfileConfig
}

func NewEcsServer(ctx context.Context, baseCredsProvider aws.CredentialsProvider, config *vault.Config, authToken string, port int, lazyLoadBaseCreds bool) (*EcsServer, error) {
func NewEcsServer(ctx context.Context, baseCredsProvider aws.CredentialsProvider, config *vault.ProfileConfig, authToken string, port int, lazyLoadBaseCreds bool) (*EcsServer, error) {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return nil, err
Expand Down
58 changes: 29 additions & 29 deletions vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,14 @@ func (c *ConfigFile) ProfileNames() []string {

// ConfigLoader loads config from configfile and environment variables
type ConfigLoader struct {
BaseConfig Config
File *ConfigFile
ActiveProfile string
BaseConfig ProfileConfig
File *ConfigFile
ActiveProfile string

visitedProfiles []string
}

func NewConfigLoader(baseConfig Config, file *ConfigFile, activeProfile string) *ConfigLoader {
func NewConfigLoader(baseConfig ProfileConfig, file *ConfigFile, activeProfile string) *ConfigLoader {
return &ConfigLoader{
BaseConfig: baseConfig,
File: file,
Expand All @@ -297,7 +298,7 @@ func (cl *ConfigLoader) resetLoopDetection() {
cl.visitedProfiles = []string{}
}

func (cl *ConfigLoader) populateFromDefaults(config *Config) {
func (cl *ConfigLoader) populateFromDefaults(config *ProfileConfig) {
if config.AssumeRoleDuration == 0 {
config.AssumeRoleDuration = DefaultSessionDuration
}
Expand All @@ -312,7 +313,7 @@ func (cl *ConfigLoader) populateFromDefaults(config *Config) {
}
}

func (cl *ConfigLoader) populateFromConfigFile(config *Config, profileName string) error {
func (cl *ConfigLoader) populateFromConfigFile(config *ProfileConfig, profileName string) error {
if !cl.visitProfile(profileName) {
return fmt.Errorf("Loop detected in config file for profile '%s'", profileName)
}
Expand Down Expand Up @@ -419,7 +420,7 @@ func (cl *ConfigLoader) populateFromConfigFile(config *Config, profileName strin
return nil
}

func (cl *ConfigLoader) populateFromEnv(profile *Config) {
func (cl *ConfigLoader) populateFromEnv(profile *ProfileConfig) {
if region := os.Getenv("AWS_REGION"); region != "" && profile.Region == "" {
log.Printf("Using region %q from AWS_REGION", region)
profile.Region = region
Expand Down Expand Up @@ -501,9 +502,9 @@ func (cl *ConfigLoader) populateFromEnv(profile *Config) {
}
}

func (cl *ConfigLoader) hydrateSourceConfig(config *Config) error {
func (cl *ConfigLoader) hydrateSourceConfig(config *ProfileConfig) error {
if config.SourceProfileName != "" {
sc, err := cl.LoadFromProfile(config.SourceProfileName)
sc, err := cl.GetProfileConfig(config.SourceProfileName)
if err != nil {
return err
}
Expand All @@ -513,8 +514,8 @@ func (cl *ConfigLoader) hydrateSourceConfig(config *Config) error {
return nil
}

// LoadFromProfile loads the profile from the config file and environment variables into config
func (cl *ConfigLoader) LoadFromProfile(profileName string) (*Config, error) {
// GetProfileConfig loads the profile from the config file and environment variables into config
func (cl *ConfigLoader) GetProfileConfig(profileName string) (*ProfileConfig, error) {
config := cl.BaseConfig
config.ProfileName = profileName
cl.populateFromEnv(&config)
Expand All @@ -535,19 +536,19 @@ func (cl *ConfigLoader) LoadFromProfile(profileName string) (*Config, error) {
return &config, nil
}

// Config is a collection of configuration options for creating temporary credentials
type Config struct {
// ProfileConfig is a collection of configuration options for creating temporary credentials
type ProfileConfig struct {
// ProfileName specifies the name of the profile config
ProfileName string

// SourceProfile is the profile where credentials come from
SourceProfileName string

// SourceProfile is the profile where credentials come from
SourceProfile *Config
SourceProfile *ProfileConfig

// ChainedFromProfile is the profile that used this profile as it's source profile
ChainedFromProfile *Config
// ChainedFromProfile is the profile that used this profile as its source profile
ChainedFromProfile *ProfileConfig

// Region is the AWS region
Region string
Expand Down Expand Up @@ -619,7 +620,7 @@ type Config struct {
}

// SetSessionTags parses a comma separated key=vaue string and sets Config.SessionTags map
func (c *Config) SetSessionTags(s string) error {
func (c *ProfileConfig) SetSessionTags(s string) error {
c.SessionTags = make(map[string]string)
for _, tag := range strings.Split(s, ",") {
kvPair := strings.SplitN(tag, "=", 2)
Expand All @@ -633,54 +634,54 @@ func (c *Config) SetSessionTags(s string) error {
}

// SetTransitiveSessionTags parses a comma separated string and sets Config.TransitiveSessionTags
func (c *Config) SetTransitiveSessionTags(s string) {
func (c *ProfileConfig) SetTransitiveSessionTags(s string) {
for _, tag := range strings.Split(s, ",") {
if tag = strings.TrimSpace(tag); tag != "" {
c.TransitiveSessionTags = append(c.TransitiveSessionTags, tag)
}
}
}

func (c *Config) IsChained() bool {
func (c *ProfileConfig) IsChained() bool {
return c.ChainedFromProfile != nil
}

func (c *Config) HasSourceProfile() bool {
func (c *ProfileConfig) HasSourceProfile() bool {
return c.SourceProfile != nil
}

func (c *Config) HasMfaSerial() bool {
func (c *ProfileConfig) HasMfaSerial() bool {
return c.MfaSerial != ""
}

func (c *Config) HasRole() bool {
func (c *ProfileConfig) HasRole() bool {
return c.RoleARN != ""
}

func (c *Config) HasSSOSession() bool {
func (c *ProfileConfig) HasSSOSession() bool {
return c.SSOSession != ""
}

func (c *Config) HasSSOStartURL() bool {
func (c *ProfileConfig) HasSSOStartURL() bool {
return c.SSOStartURL != ""
}

func (c *Config) HasWebIdentity() bool {
func (c *ProfileConfig) HasWebIdentity() bool {
return c.WebIdentityTokenFile != "" || c.WebIdentityTokenProcess != ""
}

func (c *Config) HasCredentialProcess() bool {
func (c *ProfileConfig) HasCredentialProcess() bool {
return c.CredentialProcess != ""
}

func (c *Config) GetSessionTokenDuration() time.Duration {
func (c *ProfileConfig) GetSessionTokenDuration() time.Duration {
if c.IsChained() {
return c.ChainedGetSessionTokenDuration
}
return c.NonChainedGetSessionTokenDuration
}

func (c *Config) Validate() error {
func (c *ProfileConfig) Validate() error {
if c.HasSSOSession() && !c.HasSSOStartURL() {
return fmt.Errorf("profile '%s' has sso_session but no sso_start_url", c.ProfileName)
}
Expand All @@ -700,7 +701,6 @@ func (c *Config) Validate() error {
} else if c.HasRole() {
n++
}

if n > 1 {
return fmt.Errorf("profile '%s' has more than one source of credentials", c.ProfileName)
}
Expand Down
22 changes: 11 additions & 11 deletions vault/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func TestIncludeProfile(t *testing.T) {
}

configLoader := &vault.ConfigLoader{File: configFile}
config, err := configLoader.LoadFromProfile("testincludeprofile2")
config, err := configLoader.GetProfileConfig("testincludeprofile2")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand All @@ -274,7 +274,7 @@ func TestIncludeSsoSession(t *testing.T) {
}

configLoader := &vault.ConfigLoader{File: configFile}
config, err := configLoader.LoadFromProfile("with-sso-session")
config, err := configLoader.GetProfileConfig("with-sso-session")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand Down Expand Up @@ -369,7 +369,7 @@ source_profile=foo
}

configLoader := &vault.ConfigLoader{File: configFile}
config, err := configLoader.LoadFromProfile("foo")
config, err := configLoader.GetProfileConfig("foo")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand Down Expand Up @@ -406,7 +406,7 @@ source_profile=root
}

configLoader := &vault.ConfigLoader{File: configFile}
config, err := configLoader.LoadFromProfile("foo")
config, err := configLoader.GetProfileConfig("foo")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand Down Expand Up @@ -441,7 +441,7 @@ func TestSetSessionTags(t *testing.T) {
}

for _, tc := range testCases {
config := vault.Config{}
config := vault.ProfileConfig{}
err := config.SetSessionTags(tc.stringValue)
if tc.ok {
if err != nil {
Expand Down Expand Up @@ -473,7 +473,7 @@ func TestSetTransitiveSessionTags(t *testing.T) {
}

for _, tc := range testCases {
config := vault.Config{}
config := vault.ProfileConfig{}
config.SetTransitiveSessionTags(tc.stringValue)
if !reflect.DeepEqual(tc.expected, config.TransitiveSessionTags) {
t.Fatalf("Expected TransitiveSessionTags: %+v, got %+v", tc.expected, config.TransitiveSessionTags)
Expand All @@ -496,7 +496,7 @@ transitive_session_tags = tagOne ,tagTwo,tagThree
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "tagged"}
config, err := configLoader.LoadFromProfile("tagged")
config, err := configLoader.GetProfileConfig("tagged")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand Down Expand Up @@ -533,7 +533,7 @@ transitive_session_tags = tagOne ,tagTwo,tagThree
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "tagged"}
config, err := configLoader.LoadFromProfile("tagged")
config, err := configLoader.GetProfileConfig("tagged")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand Down Expand Up @@ -578,7 +578,7 @@ source_profile = interim
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"}
config, err := configLoader.LoadFromProfile("target")
config, err := configLoader.GetProfileConfig("target")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
Expand Down Expand Up @@ -640,13 +640,13 @@ credential_process = true
configFile, _ := vault.LoadConfig(f)
configLoader := &vault.ConfigLoader{File: configFile}

config, _ := configLoader.LoadFromProfile("foo:staging")
config, _ := configLoader.GetProfileConfig("foo:staging")
err := config.Validate()
if err != nil {
t.Fatalf("Should have validated: %v", err)
}

config, _ = configLoader.LoadFromProfile("foo:production")
config, _ = configLoader.GetProfileConfig("foo:production")
err = config.Validate()
if err == nil {
t.Fatalf("Should have failed validation: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion vault/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (m *Mfa) GetMfaSerial() string {
return m.mfaSerial
}

func NewMfa(config *Config) *Mfa {
func NewMfa(config *ProfileConfig) *Mfa {
m := Mfa{
mfaSerial: config.MfaSerial,
}
Expand Down
Loading

0 comments on commit cf78af3

Please sign in to comment.