diff --git a/cli/exec.go b/cli/exec.go index 343974916..a2c84bd89 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -5,7 +5,7 @@ import ( "fmt" "log" "os" - osexec "os/exec" + "os/exec" "os/signal" "runtime" "strings" @@ -15,7 +15,6 @@ import ( "github.com/99designs/aws-vault/server" "github.com/99designs/aws-vault/vault" "github.com/99designs/keyring" - "github.com/aws/aws-sdk-go/aws/credentials" "gopkg.in/alecthomas/kingpin.v2" ) @@ -24,8 +23,7 @@ type ExecCommandInput struct { Command string Args []string Keyring keyring.Keyring - StartEc2Server bool - StartEcsServer bool + StartServer bool CredentialHelper bool Config vault.Config SessionDuration time.Duration @@ -38,7 +36,7 @@ type AwsCredentialHelperData struct { Version int `json:"Version"` AccessKeyID string `json:"AccessKeyId"` SecretAccessKey string `json:"SecretAccessKey"` - SessionToken string `json:"SessionToken,omitempty"` + SessionToken string `json:"SessionToken"` Expiration string `json:"Expiration,omitempty"` } @@ -66,17 +64,9 @@ func ConfigureExecCommand(app *kingpin.Application) { Short('j'). BoolVar(&input.CredentialHelper) - cmd.Flag("server", "Run a server in the background for credentials"). + cmd.Flag("server", "Run the server in the background for credentials"). Short('s'). - BoolVar(&input.StartEc2Server) - - cmd.Flag("ec2-server", "Run a EC2 metadata server in the background for credentials"). - Hidden(). - BoolVar(&input.StartEc2Server) - - cmd.Flag("ecs-server", "Run a ECS credential server in the background for credentials"). - Hidden(). - BoolVar(&input.StartEcsServer) + BoolVar(&input.StartServer) cmd.Arg("profile", "Name of the profile"). Required(). @@ -95,7 +85,7 @@ func ConfigureExecCommand(app *kingpin.Application) { input.Config.MfaPromptMethod = GlobalFlags.PromptDriver input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration - app.FatalIfError(ExecCommand(input), "") + app.FatalIfError(ExecCommand(input), "exec") return nil }) } @@ -105,23 +95,8 @@ func ExecCommand(input ExecCommandInput) error { return fmt.Errorf("aws-vault sessions should be nested with care, unset $AWS_VAULT to force") } - if input.StartEc2Server && input.StartEcsServer { - return fmt.Errorf("Can't use --server with --ecs-server") - } - if input.StartEc2Server && input.CredentialHelper { - return fmt.Errorf("Can't use --server with --json") - } - if input.StartEc2Server && input.NoSession { - return fmt.Errorf("Can't use --server with --no-session") - } - if input.StartEcsServer && input.CredentialHelper { - return fmt.Errorf("Can't use --ecs-server with --json") - } - if input.StartEcsServer && input.NoSession { - return fmt.Errorf("Can't use --ecs-server with --no-session") - } - vault.UseSession = !input.NoSession + setEnv := true configLoader.BaseConfig = input.Config configLoader.ActiveProfile = input.ProfileName @@ -136,125 +111,83 @@ func ExecCommand(input ExecCommandInput) error { return fmt.Errorf("Error getting temporary credentials: %w", err) } - if input.StartEc2Server { - return execEc2Server(input, config, creds) - } - - if input.StartEcsServer { - return execEcsServer(input, config, creds) - } - - if input.CredentialHelper { - return execCredentialHelper(input, config, creds) - } - - return execEnvironment(input, config, creds) -} - -func updateEnvForAwsVault(env environ, profileName string, region string) environ { - env.Unset("AWS_ACCESS_KEY_ID") - env.Unset("AWS_SECRET_ACCESS_KEY") - env.Unset("AWS_SESSION_TOKEN") - env.Unset("AWS_SECURITY_TOKEN") - env.Unset("AWS_CREDENTIAL_FILE") - env.Unset("AWS_DEFAULT_PROFILE") - env.Unset("AWS_PROFILE") - env.Unset("AWS_SDK_LOAD_CONFIG") - - env.Set("AWS_VAULT", profileName) - - if region != "" { - log.Printf("Setting subprocess env: AWS_DEFAULT_REGION=%s, AWS_REGION=%s", region, region) - env.Set("AWS_DEFAULT_REGION", region) - env.Set("AWS_REGION", region) - } - - return env -} - -func execEc2Server(input ExecCommandInput, config *vault.Config, creds *credentials.Credentials) error { - if err := server.StartEc2CredentialsServer(creds, config.Region); err != nil { - return fmt.Errorf("Failed to start credential server: %w", err) - } - - env := environ(os.Environ()) - env = updateEnvForAwsVault(env, input.ProfileName, config.Region) - - return execCmd(input.Command, input.Args, env) -} - -func execEcsServer(input ExecCommandInput, config *vault.Config, creds *credentials.Credentials) error { - uri, token, err := server.StartEcsCredentialServer(creds) - if err != nil { - return fmt.Errorf("Failed to start credential server: %w", err) - } - - env := environ(os.Environ()) - env = updateEnvForAwsVault(env, input.ProfileName, config.Region) - - log.Println("Setting subprocess env AWS_CONTAINER_CREDENTIALS_FULL_URI, AWS_CONTAINER_AUTHORIZATION_TOKEN") - env.Set("AWS_CONTAINER_CREDENTIALS_FULL_URI", uri) - env.Set("AWS_CONTAINER_AUTHORIZATION_TOKEN", token) - - return execCmd(input.Command, input.Args, env) -} - -func execCredentialHelper(input ExecCommandInput, config *vault.Config, creds *credentials.Credentials) error { val, err := creds.Get() if err != nil { return fmt.Errorf("Failed to get credentials for %s: %w", input.ProfileName, err) } - credentialData := AwsCredentialHelperData{ - Version: 1, - AccessKeyID: val.AccessKeyID, - SecretAccessKey: val.SecretAccessKey, - } - if val.SessionToken != "" { - credentialData.SessionToken = val.SessionToken - } - if credsExpiresAt, err := creds.ExpiresAt(); err == nil { - credentialData.Expiration = credsExpiresAt.Format("2006-01-02T15:04:05Z") - } - - json, err := json.Marshal(&credentialData) - if err != nil { - return fmt.Errorf("Error creating credential json: %w", err) - } - - fmt.Print(string(json)) - - return nil -} - -func execEnvironment(input ExecCommandInput, config *vault.Config, creds *credentials.Credentials) error { - val, err := creds.Get() - if err != nil { - return fmt.Errorf("Failed to get credentials for %s: %w", input.ProfileName, err) + if input.StartServer { + if err := server.StartCredentialsServer(creds); err != nil { + return fmt.Errorf("Failed to start credential server: %w", err) + } + setEnv = false } - env := environ(os.Environ()) - env = updateEnvForAwsVault(env, input.ProfileName, config.Region) + if input.CredentialHelper { + credentialData := AwsCredentialHelperData{ + Version: 1, + AccessKeyID: val.AccessKeyID, + SecretAccessKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + } + if !input.NoSession { + credsExprest, err := creds.ExpiresAt() + if err != nil { + return fmt.Errorf("Error getting credential expiration: %w", err) + } + credentialData.Expiration = credsExprest.Format("2006-01-02T15:04:05Z") + } + json, err := json.Marshal(&credentialData) + if err != nil { + return fmt.Errorf("Error creating credential json: %w", err) + } + fmt.Print(string(json)) + } else { + + env := environ(os.Environ()) + env.Set("AWS_VAULT", input.ProfileName) + + env.Unset("AWS_ACCESS_KEY_ID") + env.Unset("AWS_SECRET_ACCESS_KEY") + env.Unset("AWS_CREDENTIAL_FILE") + env.Unset("AWS_DEFAULT_PROFILE") + env.Unset("AWS_PROFILE") + + if config.Region != "" { + log.Printf("Setting subprocess env: AWS_DEFAULT_REGION=%s, AWS_REGION=%s", config.Region, config.Region) + env.Set("AWS_DEFAULT_REGION", config.Region) + env.Set("AWS_REGION", config.Region) + } - log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY") - env.Set("AWS_ACCESS_KEY_ID", val.AccessKeyID) - env.Set("AWS_SECRET_ACCESS_KEY", val.SecretAccessKey) + if setEnv { + log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY") + env.Set("AWS_ACCESS_KEY_ID", val.AccessKeyID) + env.Set("AWS_SECRET_ACCESS_KEY", val.SecretAccessKey) + + if val.SessionToken != "" { + log.Println("Setting subprocess env: AWS_SESSION_TOKEN, AWS_SECURITY_TOKEN") + env.Set("AWS_SESSION_TOKEN", val.SessionToken) + env.Set("AWS_SECURITY_TOKEN", val.SessionToken) + expiration, err := creds.ExpiresAt() + if err == nil { + log.Println("Setting subprocess env: AWS_SESSION_EXPIRATION") + env.Set("AWS_SESSION_EXPIRATION", expiration.Format(time.RFC3339)) + } + } + } - if val.SessionToken != "" { - log.Println("Setting subprocess env: AWS_SESSION_TOKEN, AWS_SECURITY_TOKEN") - env.Set("AWS_SESSION_TOKEN", val.SessionToken) - env.Set("AWS_SECURITY_TOKEN", val.SessionToken) - } - if expiration, err := creds.ExpiresAt(); err == nil { - log.Println("Setting subprocess env: AWS_SESSION_EXPIRATION") - env.Set("AWS_SESSION_EXPIRATION", expiration.Format(time.RFC3339)) - } + if input.StartServer { + err = execCmd(input.Command, input.Args, env) + } else { + err = execSyscall(input.Command, input.Args, env) + } - if !supportsExecSyscall() { - return execCmd(input.Command, input.Args, env) + if err != nil { + return fmt.Errorf("Error execing process: %w", err) + } } - return execSyscall(input.Command, input.Args, env) + return nil } // environ is a slice of strings representing the environment, in the form "key=value". @@ -278,9 +211,7 @@ func (e *environ) Set(key, val string) { } func execCmd(command string, args []string, env []string) error { - log.Printf("Starting child process: %s %s", command, strings.Join(args, " ")) - - cmd := osexec.Command(command, args...) + cmd := exec.Command(command, args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -290,7 +221,7 @@ func execCmd(command string, args []string, env []string) error { signal.Notify(sigChan) if err := cmd.Start(); err != nil { - return err + return fmt.Errorf("Failed to start command: %v", err) } go func() { @@ -315,9 +246,11 @@ func supportsExecSyscall() bool { } func execSyscall(command string, args []string, env []string) error { - log.Printf("Exec command %s %s", command, strings.Join(args, " ")) + if !supportsExecSyscall() { + return execCmd(command, args, env) + } - argv0, err := osexec.LookPath(command) + argv0, err := exec.LookPath(command) if err != nil { return err } diff --git a/cli/server.go b/cli/server.go index 8a6c6c5bd..5050e7f0a 100644 --- a/cli/server.go +++ b/cli/server.go @@ -21,7 +21,7 @@ func ConfigureServerCommand(app *kingpin.Application) { } func ServerCommand(app *kingpin.Application, input ServerCommandInput) { - if err := server.StartEc2MetadataEndpointProxy(); err != nil { + if err := server.StartMetadataServer(); err != nil { app.Fatalf("Server failed: %v", err) } } diff --git a/server/ec2alias_bsd.go b/server/alias_bsd.go similarity index 72% rename from server/ec2alias_bsd.go rename to server/alias_bsd.go index 5d6c9870f..b514f6ed9 100644 --- a/server/ec2alias_bsd.go +++ b/server/alias_bsd.go @@ -4,6 +4,6 @@ package server import "os/exec" -func installEc2EndpointNetworkAlias() ([]byte, error) { +func installNetworkAlias() ([]byte, error) { return exec.Command("ifconfig", "lo0", "alias", "169.254.169.254").CombinedOutput() } diff --git a/server/ec2alias_linux.go b/server/alias_linux.go similarity index 74% rename from server/ec2alias_linux.go rename to server/alias_linux.go index a542a6bb8..d86a99b7a 100644 --- a/server/ec2alias_linux.go +++ b/server/alias_linux.go @@ -4,6 +4,6 @@ package server import "os/exec" -func installEc2EndpointNetworkAlias() ([]byte, error) { +func installNetworkAlias() ([]byte, error) { return exec.Command("ip", "addr", "add", "169.254.169.254/24", "dev", "lo", "label", "lo:0").CombinedOutput() } diff --git a/server/ec2alias_windows.go b/server/alias_windows.go similarity index 90% rename from server/ec2alias_windows.go rename to server/alias_windows.go index a57856feb..9cdcad339 100644 --- a/server/ec2alias_windows.go +++ b/server/alias_windows.go @@ -8,7 +8,7 @@ import ( "strings" ) -func installEc2EndpointNetworkAlias() ([]byte, error) { +func installNetworkAlias() ([]byte, error) { out, err := exec.Command("netsh", "interface", "ipv4", "add", "address", "Loopback Pseudo-Interface 1", "169.254.169.254", "255.255.0.0").CombinedOutput() if err == nil || strings.Contains(string(out), "The object already exists") { diff --git a/server/ec2.go b/server/ec2.go deleted file mode 100644 index 737c1c588..000000000 --- a/server/ec2.go +++ /dev/null @@ -1,147 +0,0 @@ -package server - -import ( - "encoding/json" - "fmt" - "log" - "net" - "net/http" - "net/http/httputil" - "net/url" - "time" - - "github.com/aws/aws-sdk-go/aws/credentials" -) - -const ( - awsTimeFormat = "2006-01-02T15:04:05Z" - ec2MetadataEndpointAddr = "169.254.169.254:80" - ec2CredentialsServerAddr = "127.0.0.1:9099" -) - -// StartEc2MetadataEndpointProxy starts a http proxy server on the standard EC2 Instance Metadata endpoint -func StartEc2MetadataEndpointProxy() error { - var localServerURL, err = url.Parse(fmt.Sprintf("http://%s/", ec2CredentialsServerAddr)) - if err != nil { - log.Fatal(err) - } - - if _, err := installEc2EndpointNetworkAlias(); err != nil { - return err - } - - l, err := net.Listen("tcp", ec2MetadataEndpointAddr) - if err != nil { - return err - } - - log.Printf("EC2 Instance Metadata endpoint proxy server running on %s", l.Addr()) - return http.Serve(l, httputil.NewSingleHostReverseProxy(localServerURL)) -} - -func isServerRunning(bind string) bool { - _, err := net.DialTimeout("tcp", bind, time.Millisecond*10) - return err == nil -} - -// StartEc2CredentialsServer starts a EC2 Instance Metadata server and endpoint proxy -func StartEc2CredentialsServer(creds *credentials.Credentials, region string) error { - if !isServerRunning(ec2MetadataEndpointAddr) { - if err := StartEc2EndpointProxyServerProcess(); err != nil { - return err - } - } - - // pre-fetch credentials so that we can respond quickly to the first request - _, _ = creds.Get() - - go startEc2CredentialsServer(creds, region) - - return nil -} - -func startEc2CredentialsServer(creds *credentials.Credentials, region string) { - - log.Printf("Starting EC2 Instance Metadata server on %s", ec2CredentialsServerAddr) - router := http.NewServeMux() - - router.HandleFunc("/latest/meta-data/iam/security-credentials/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "local-credentials") - }) - - // The AWS Go SDK checks the instance-id endpoint to validate the existence of EC2 Metadata - router.HandleFunc("/latest/meta-data/instance-id/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "aws-vault") - }) - - // The AWS .NET SDK checks this endpoint during obtaining credentials/refreshing them - router.HandleFunc("/latest/meta-data/iam/info/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, `{"Code" : "Success"}`) - }) - - // used by AWS SDK to determine region - router.HandleFunc("/latest/meta-data/dynamic/instance-identity/document", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, `{"region": "`+region+`"}`) - }) - - router.HandleFunc("/latest/meta-data/iam/security-credentials/local-credentials", credsHandler(creds)) - - log.Fatalln(http.ListenAndServe(ec2CredentialsServerAddr, withLogging(withLoopbackSecurityCheck(router)))) -} - -// withLoopbackSecurityCheck is middleware to check that the request comes from the loopback device -// We must make sure the remote ip is from the loopback, otherwise clients on the same network segment could -// potentially route traffic via 169.254.169.254:80 -// See https://developer.apple.com/library/content/qa/qa1357/_index.html -func withLoopbackSecurityCheck(next *http.ServeMux) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - if !net.ParseIP(ip).IsLoopback() { - http.Error(w, "Access denied from non-localhost address", http.StatusUnauthorized) - return - } - - next.ServeHTTP(w, r) - } -} - -func credsHandler(creds *credentials.Credentials) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - log.Printf("Credentials.IsExpired() = %#v", creds.IsExpired()) - - val, err := creds.Get() - if err != nil { - http.Error(w, err.Error(), http.StatusGatewayTimeout) - return - } - credsExpiresAt, err := creds.ExpiresAt() - if err != nil { - http.Error(w, err.Error(), http.StatusGatewayTimeout) - return - } - - log.Printf("Serving credentials via http ****************%s, expiration of %s (%s)", - val.AccessKeyID[len(val.AccessKeyID)-4:], - credsExpiresAt.Format(awsTimeFormat), - time.Until(credsExpiresAt).String()) - - err = json.NewEncoder(w).Encode(map[string]interface{}{ - "Code": "Success", - "LastUpdated": time.Now().Format(awsTimeFormat), - "Type": "AWS-HMAC", - "AccessKeyId": val.AccessKeyID, - "SecretAccessKey": val.SecretAccessKey, - "Token": val.SessionToken, - "Expiration": credsExpiresAt.Format(awsTimeFormat), - }) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } -} diff --git a/server/ecs.go b/server/ecs.go deleted file mode 100644 index 46fcc1dd8..000000000 --- a/server/ecs.go +++ /dev/null @@ -1,89 +0,0 @@ -package server - -import ( - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "log" - "net" - "net/http" - - "github.com/aws/aws-sdk-go/aws/credentials" -) - -func writeErrorMessage(w http.ResponseWriter, msg string, statusCode int) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(statusCode) - if err := json.NewEncoder(w).Encode(map[string]string{"Message": msg}); err != nil { - log.Println(err.Error()) - } -} - -func withAuthorizationCheck(token string, next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != token { - writeErrorMessage(w, "invalid Authorization token", http.StatusForbidden) - return - } - next.ServeHTTP(w, r) - } -} - -// StartEcsCredentialServer starts an ECS credential server on a random port -func StartEcsCredentialServer(creds *credentials.Credentials) (string, string, error) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return "", "", err - } - token, err := generateRandomString() - if err != nil { - return "", "", err - } - - go func() { - err := http.Serve(listener, withLogging(withAuthorizationCheck(token, ecsCredsHandler(creds)))) - // returns ErrServerClosed on graceful close - if err != http.ErrServerClosed { - log.Fatalf("ecs server: %s", err.Error()) - } - }() - - uri := fmt.Sprintf("http://%s", listener.Addr().String()) - return uri, token, nil -} - -func ecsCredsHandler(creds *credentials.Credentials) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - val, err := creds.Get() - if err != nil { - writeErrorMessage(w, err.Error(), http.StatusInternalServerError) - return - } - - credsExpiresAt, err := creds.ExpiresAt() - if err != nil { - writeErrorMessage(w, err.Error(), http.StatusInternalServerError) - return - } - - err = json.NewEncoder(w).Encode(map[string]string{ - "AccessKeyId": val.AccessKeyID, - "SecretAccessKey": val.SecretAccessKey, - "Token": val.SessionToken, - "Expiration": credsExpiresAt.Format("2006-01-02T15:04:05Z"), - }) - if err != nil { - writeErrorMessage(w, err.Error(), http.StatusInternalServerError) - return - } - } -} - -func generateRandomString() (string, error) { - b := make([]byte, 30) - if _, err := rand.Read(b); err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} diff --git a/server/httplog.go b/server/httplog.go deleted file mode 100644 index 35b4c8209..000000000 --- a/server/httplog.go +++ /dev/null @@ -1,26 +0,0 @@ -package server - -import ( - "log" - "net/http" - "time" -) - -type loggingMiddlewareResponseWriter struct { - http.ResponseWriter - Code int -} - -func (w *loggingMiddlewareResponseWriter) WriteHeader(statusCode int) { - w.Code = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func withLogging(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestStart := time.Now() - w2 := &loggingMiddlewareResponseWriter{w, http.StatusOK} - handler.ServeHTTP(w2, r) - log.Printf("http: %s: %d %s %s (%s)", r.RemoteAddr, w2.Code, r.Method, r.URL, time.Since(requestStart)) - }) -} diff --git a/server/ec2proxy_default.go b/server/proxy_default.go similarity index 51% rename from server/ec2proxy_default.go rename to server/proxy_default.go index 591f873be..a29d1f813 100644 --- a/server/ec2proxy_default.go +++ b/server/proxy_default.go @@ -10,8 +10,8 @@ import ( "time" ) -// StartEc2EndpointProxyServerProcess starts a `aws-vault server` process -func StartEc2EndpointProxyServerProcess() error { +// StartCredentialProxy starts a `aws-vault server` process +func StartCredentialProxy() error { log.Println("Starting `aws-vault server` in the background") cmd := exec.Command(os.Args[0], "server") cmd.Stdin = os.Stdin @@ -21,8 +21,8 @@ func StartEc2EndpointProxyServerProcess() error { return err } time.Sleep(time.Second * 1) - if !isServerRunning(ec2MetadataEndpointAddr) { - return errors.New("The EC2 Instance Metadata endpoint proxy server isn't running. Run `aws-vault server` as Administrator or root in the background and then try this command again") + if !checkServerRunning(metadataBind) { + return errors.New("The credential proxy server isn't running. Run aws-vault server as Administrator in the background and then try this command again") } return nil } diff --git a/server/ec2proxy_unix.go b/server/proxy_unix.go similarity index 70% rename from server/ec2proxy_unix.go rename to server/proxy_unix.go index 853ac8e47..34a0ec4e4 100644 --- a/server/ec2proxy_unix.go +++ b/server/proxy_unix.go @@ -8,8 +8,8 @@ import ( "os/exec" ) -// StartEc2EndpointProxyServerProcess starts a `aws-vault server` process -func StartEc2EndpointProxyServerProcess() error { +// StartCredentialProxy starts a `aws-vault server` process +func StartCredentialProxy() error { log.Println("Starting `aws-vault server` as root in the background") cmd := exec.Command("sudo", "-b", os.Args[0], "server") cmd.Stdin = os.Stdin diff --git a/server/server.go b/server/server.go new file mode 100644 index 000000000..40d23416d --- /dev/null +++ b/server/server.go @@ -0,0 +1,145 @@ +package server + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +const ( + metadataBind = "169.254.169.254:80" + awsTimeFormat = "2006-01-02T15:04:05Z" + localServerURL = "http://127.0.0.1:9099" + localServerBind = "127.0.0.1:9099" +) + +func StartMetadataServer() error { + if _, err := installNetworkAlias(); err != nil { + return err + } + + router := http.NewServeMux() + router.HandleFunc("/latest/meta-data/iam/security-credentials/", indexHandler) + router.HandleFunc("/latest/meta-data/iam/security-credentials/local-credentials", credentialsHandler) + // The AWS Go SDK checks the instance-id endpoint to validate the existence of EC2 Metadata + router.HandleFunc("/latest/meta-data/instance-id/", instanceIdHandler) + // The AWS .NET SDK checks this endpoint during obtaining credentials/refreshing them + router.HandleFunc("/latest/meta-data/iam/info/", infoHandlerStub) + + l, err := net.Listen("tcp", metadataBind) + if err != nil { + return err + } + + log.Printf("Local instance role server running on %s", l.Addr()) + return http.Serve(l, router) +} + +func infoHandlerStub(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"Code" : "Success"}`) +} + +func indexHandler(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "local-credentials") +} + +func credentialsHandler(w http.ResponseWriter, r *http.Request) { + resp, err := http.Get(localServerURL) + if err != nil { + http.Error(w, err.Error(), http.StatusGatewayTimeout) + return + } + defer resp.Body.Close() + + log.Printf("Fetched credentials from %s", localServerURL) + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + + _, err = io.Copy(w, resp.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +func instanceIdHandler(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "aws-vault") +} + +func checkServerRunning(bind string) bool { + _, err := net.DialTimeout("tcp", bind, time.Millisecond*10) + return err == nil +} + +func StartCredentialsServer(creds *credentials.Credentials) error { + if !checkServerRunning(metadataBind) { + if err := StartCredentialProxy(); err != nil { + return err + } + } + + log.Printf("Starting local instance role server on %s", localServerBind) + go func() { + log.Fatalln(http.ListenAndServe(localServerBind, credsHandler(creds))) + }() + + return nil +} + +func credsHandler(creds *credentials.Credentials) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Must make sure the remote ip is from the loopback, otherwise clients on the same network segment could + // potentially route traffic via 169.254.169.254:80 + // See https://developer.apple.com/library/content/qa/qa1357/_index.html + if !net.ParseIP(ip).IsLoopback() { + http.Error(w, "Access denied from non-localhost address", http.StatusUnauthorized) + return + } + + log.Printf("RemoteAddr = %v", r.RemoteAddr) + log.Printf("Credentials.IsExpired() = %#v", creds.IsExpired()) + + val, err := creds.Get() + if err != nil { + http.Error(w, err.Error(), http.StatusGatewayTimeout) + return + } + credsExpiresAt, err := creds.ExpiresAt() + if err != nil { + http.Error(w, err.Error(), http.StatusGatewayTimeout) + return + } + + log.Printf("Serving credentials via http ****************%s, expiration of %s (%s)", + val.AccessKeyID[len(val.AccessKeyID)-4:], + credsExpiresAt.Format(awsTimeFormat), + time.Until(credsExpiresAt).String()) + + err = json.NewEncoder(w).Encode(map[string]interface{}{ + "Code": "Success", + "LastUpdated": time.Now().Format(awsTimeFormat), + "Type": "AWS-HMAC", + "AccessKeyId": val.AccessKeyID, + "SecretAccessKey": val.SecretAccessKey, + "Token": val.SessionToken, + "Expiration": credsExpiresAt.Format(awsTimeFormat), + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } +}