diff --git a/aws/sts.go b/aws/sts.go new file mode 100644 index 0000000..22f3583 --- /dev/null +++ b/aws/sts.go @@ -0,0 +1,39 @@ +package aws + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" +) + +func AssumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion string) (*Credentials, error) { + input := sts.AssumeRoleWithSAMLInput{ + PrincipalArn: aws.String(PrincipalArn), + RoleArn: aws.String(RoleArn), + SAMLAssertion: aws.String(SAMLAssertion), + } + + sess := session.Must(session.NewSession()) + svc := sts.New(sess) + + aResp, err := svc.AssumeRoleWithSAML(&input) + if err != nil { + return nil, fmt.Errorf("assuming role: %v", err) + } + + keyID := *aResp.Credentials.AccessKeyId + secretKey := *aResp.Credentials.SecretAccessKey + sessionToken := *aResp.Credentials.SessionToken + expiration := *aResp.Credentials.Expiration + + creds := Credentials{ + AccessKeyID: keyID, + SecretAccessKey: secretKey, + SessionToken: sessionToken, + Expiration: expiration, + } + + return &creds, nil +} diff --git a/okta/get.go b/okta/get.go index e2cd4a7..394ca98 100644 --- a/okta/get.go +++ b/okta/get.go @@ -3,18 +3,15 @@ package okta import ( "fmt" - awsprovider "github.com/allcloud-io/clisso/aws" + "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/config" "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/howeyc/gopass" ) // Get gets temporary credentials for the given app. -func Get(app, provider string) (*awsprovider.Credentials, error) { +func Get(app, provider string) (*aws.Credentials, error) { // Get provider config p, err := config.GetOktaProvider(provider) if err != nil { @@ -102,34 +99,9 @@ func Get(app, provider string) (*awsprovider.Credentials, error) { return nil, err } - // Assume role - input := sts.AssumeRoleWithSAMLInput{ - PrincipalArn: aws.String(arn.Provider), - RoleArn: aws.String(arn.Role), - SAMLAssertion: aws.String(*samlAssertion), - } - - sess := session.Must(session.NewSession()) - svc := sts.New(sess) - s.Start() - aResp, err := svc.AssumeRoleWithSAML(&input) + creds, err := aws.AssumeSAMLRole(arn.Provider, arn.Role, *samlAssertion) s.Stop() - if err != nil { - return nil, fmt.Errorf("assuming role: %v", err) - } - - keyID := *aResp.Credentials.AccessKeyId - secretKey := *aResp.Credentials.SecretAccessKey - sessionToken := *aResp.Credentials.SessionToken - expiration := *aResp.Credentials.Expiration - - creds := awsprovider.Credentials{ - AccessKeyID: keyID, - SecretAccessKey: secretKey, - SessionToken: sessionToken, - Expiration: expiration, - } - return &creds, nil + return creds, err } diff --git a/onelogin/get.go b/onelogin/get.go index 097476b..629e86c 100644 --- a/onelogin/get.go +++ b/onelogin/get.go @@ -4,13 +4,10 @@ import ( "fmt" "time" - awsprovider "github.com/allcloud-io/clisso/aws" + "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/config" "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/howeyc/gopass" ) @@ -29,7 +26,7 @@ const ( // Get gets temporary credentials for the given app. // TODO Move AWS logic outside this function. -func Get(app, provider string) (*awsprovider.Credentials, error) { +func Get(app, provider string) (*aws.Credentials, error) { // Read config p, err := config.GetOneLoginProvider(provider) if err != nil { @@ -185,34 +182,9 @@ func Get(app, provider string) (*awsprovider.Credentials, error) { return nil, err } - // Assume role - pAssumeRole := sts.AssumeRoleWithSAMLInput{ - PrincipalArn: aws.String(arn.Provider), - RoleArn: aws.String(arn.Role), - SAMLAssertion: aws.String(rMfa.Data), - } - - sess := session.Must(session.NewSession()) - svc := sts.New(sess) - s.Start() - resp, err := svc.AssumeRoleWithSAML(&pAssumeRole) + creds, err := aws.AssumeSAMLRole(arn.Provider, arn.Role, rMfa.Data) s.Stop() - if err != nil { - return nil, fmt.Errorf("assuming role: %v", err) - } - - keyID := *resp.Credentials.AccessKeyId - secretKey := *resp.Credentials.SecretAccessKey - sessionToken := *resp.Credentials.SessionToken - expiration := *resp.Credentials.Expiration - - creds := awsprovider.Credentials{ - AccessKeyID: keyID, - SecretAccessKey: secretKey, - SessionToken: sessionToken, - Expiration: expiration, - } - return &creds, nil + return creds, err }