From 76ac14baae248bb063921556e7ffd9277552363a Mon Sep 17 00:00:00 2001 From: Alex Budkar Date: Sat, 25 Mar 2023 21:51:07 -0700 Subject: [PATCH] Support for aws-vault exec --server on WSL Currently when running on WSL the only supported prompt driver is terminal Terminal prompt is not compatible with the server mode To workaround we run credentials server on windows host and take advatage of windows creds storage and UX On WSL linux we run proxy command that talks to the credentials server running on host Because we don't need to run proxy on the windows host we made it optional If proxy is disabled we would not need priveledge elevation --- cli/exec.go | 41 ++++++++++--- cli/proxy.go | 41 ++++++++++++- server/ec2alias_bsd.go | 4 ++ server/ec2alias_linux.go | 10 ++- server/ec2alias_windows.go | 30 +++++++++ server/ec2proxy.go | 4 +- server/ec2server.go | 121 ++++++++++++++++++++++++++++--------- 7 files changed, 212 insertions(+), 39 deletions(-) diff --git a/cli/exec.go b/cli/exec.go index 47f7f4c54..c966b6c70 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -27,6 +27,8 @@ type ExecCommandInput struct { Args []string StartEc2Server bool StartEcsServer bool + DontStartProxy bool + IsWsl bool Lazy bool JSONDeprecated bool Config vault.ProfileConfig @@ -98,6 +100,15 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { cmd.Flag("ec2-server", "Run a EC2 metadata server in the background for credentials"). BoolVar(&input.StartEc2Server) + cmd.Flag("no-proxy", "If set will not start local proxy"). + BoolVar(&input.DontStartProxy) + + //goland:noinspection GoBoolExpressions + if runtime.GOOS == "windows" { + cmd.Flag("wsl", "If set will bind to wsl interface and allow requests from wsl"). + BoolVar(&input.IsWsl) + } + cmd.Flag("ecs-server", "Run a ECS credential server in the background for credentials (the SDK or app must support AWS_CONTAINER_CREDENTIALS_FULL_URI)"). BoolVar(&input.StartEcsServer) @@ -186,17 +197,33 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke cmdEnv := createEnv(input.ProfileName, config.Region) if input.StartEc2Server { - if server.IsProxyRunning() { - return 0, fmt.Errorf("Another process is already bound to 169.254.169.254:80") + if !input.DontStartProxy { + if server.IsProxyRunning() { + return 0, fmt.Errorf("Another process is already bound to 169.254.169.254:80") + } + + printHelpMessage("Warning: Starting a local EC2 credential server on 169.254.169.254:80; AWS credentials will be accessible to any process while it is running", input.ShowHelpMessages) + if err := server.StartEc2EndpointProxyServerProcess(); err != nil { + return 0, err + } + defer server.StopProxy() } - printHelpMessage("Warning: Starting a local EC2 credential server on 169.254.169.254:80; AWS credentials will be accessible to any process while it is running", input.ShowHelpMessages) - if err := server.StartEc2EndpointProxyServerProcess(); err != nil { - return 0, err + extraParams := make([]server.Ec2ServerParameter, 0) + if input.IsWsl { + ip, net, err := server.GetWslAddressAndNetwork() + if err != nil { + return 1, err + } + extraParams = append(extraParams, + server.WithEc2ServerAddress(ip.String()+":"+server.DefaultEc2CredentialsServerPort), + server.WithEc2ServerAllowedNetwork(*net), + ) } - defer server.StopProxy() - if err = server.StartEc2CredentialsServer(context.TODO(), credsProvider, config.Region); err != nil { + serverParams := server.NewEc2ServerParameters(config.Region, extraParams...) + + if err = server.StartEc2CredentialsServer(context.TODO(), credsProvider, serverParams); err != nil { return 0, fmt.Errorf("Failed to start credential server: %w", err) } printHelpMessage(subshellHelp, input.ShowHelpMessages) diff --git a/cli/proxy.go b/cli/proxy.go index 1ab260614..92ee494d4 100644 --- a/cli/proxy.go +++ b/cli/proxy.go @@ -1,8 +1,13 @@ package cli import ( + "fmt" + "net" "os" + "os/exec" "os/signal" + "runtime" + "strings" "syscall" "github.com/99designs/aws-vault/v7/server" @@ -11,7 +16,8 @@ import ( func ConfigureProxyCommand(app *kingpin.Application) { stop := false - + isWsl := false + serverAddress := "" cmd := app.Command("proxy", "Start a proxy for the ec2 instance role server locally."). Alias("server"). Hidden() @@ -19,16 +25,47 @@ func ConfigureProxyCommand(app *kingpin.Application) { cmd.Flag("stop", "Stop the proxy"). BoolVar(&stop) + cmd.Flag("credentials-server-address", "Server address"). + Default(server.DefaultEc2CredentialsServerAddr). + Hidden(). + StringVar(&serverAddress) + + //goland:noinspection GoBoolExpressions + if runtime.GOOS == "linux" { + cmd.Flag("wsl", "Proxy to credentials server running on Windows host"). + BoolVar(&isWsl) + } + cmd.Action(func(*kingpin.ParseContext) error { if stop { server.StopProxy() return nil } handleSigTerm() - return server.StartProxy() + if (serverAddress == server.DefaultEc2CredentialsServerAddr) && isWsl { + ip, err := getWslHost() + if err != nil { + return err + } + serverAddress = ip.String() + ":" + server.DefaultEc2CredentialsServerPort + } + return server.StartProxy(serverAddress) }) } +func getWslHost() (net.IP, error) { + out, err := exec.Command("ip", "route").CombinedOutput() + if err != nil { + return net.IP{}, err + } + for _, line := range strings.Split(string(out), "\n") { + if strings.Contains(line, "default") { + return net.ParseIP(strings.Split(line, " ")[2]), nil + } + } + return nil, fmt.Errorf("unable to find default gateway") +} + func handleSigTerm() { // shutdown c := make(chan os.Signal, 1) diff --git a/server/ec2alias_bsd.go b/server/ec2alias_bsd.go index 2ca198241..3f472cb4d 100644 --- a/server/ec2alias_bsd.go +++ b/server/ec2alias_bsd.go @@ -5,6 +5,10 @@ package server import "os/exec" +func GetWslAddressAndNetwork() (net.IP, *net.IPNet, error) { + return net.IP{}, net.IPNet{}, fmt.Errorf("WSL is a Windows only feature") +} + func installEc2EndpointNetworkAlias() ([]byte, error) { return exec.Command("ifconfig", "lo0", "alias", "169.254.169.254").CombinedOutput() } diff --git a/server/ec2alias_linux.go b/server/ec2alias_linux.go index a137c292a..f5e014b3c 100644 --- a/server/ec2alias_linux.go +++ b/server/ec2alias_linux.go @@ -3,7 +3,15 @@ package server -import "os/exec" +import ( + "fmt" + "net" + "os/exec" +) + +func GetWslAddressAndNetwork() (net.IP, *net.IPNet, error) { + return net.IP{}, net.IPNet{}, fmt.Errorf("WSL is a Windows only feature") +} func installEc2EndpointNetworkAlias() ([]byte, error) { return exec.Command("ip", "addr", "add", "169.254.169.254/24", "dev", "lo", "label", "lo:0").CombinedOutput() diff --git a/server/ec2alias_windows.go b/server/ec2alias_windows.go index adf948f75..34e86c70e 100644 --- a/server/ec2alias_windows.go +++ b/server/ec2alias_windows.go @@ -5,6 +5,7 @@ package server import ( "fmt" + "net" "os/exec" "strings" ) @@ -41,6 +42,35 @@ func runAndWrapAdminErrors(name string, arg ...string) ([]byte, error) { return out, err } +func GetWslAddressAndNetwork() (net.IP, *net.IPNet, error) { + out, err := runAndWrapAdminErrors("netsh", "interface", "ipv4", "show", "addresses", "vEthernet (WSL)") + if err != nil { + return net.IP{}, nil, err + } + ip := net.IP{} + nt := &net.IPNet{} + + lines := strings.Split(string(out), "\n") + for _, line := range lines { + if strings.Contains(line, "IP Address:") { + sip := strings.Trim(strings.Split(line, ":")[1], " \n\r\t") + if ip = net.ParseIP(sip); ip == nil { + return net.IP{}, nil, fmt.Errorf("Unable to parse IP address from WSL interface: %s", sip) + } + } + if strings.Contains(line, "Subnet Prefix:") { + snt := strings.Split(strings.Trim(strings.Split(line, ":")[1], " \n\r\t"), " ")[0] + if _, nt, err = net.ParseCIDR(snt); err != nil { + return net.IP{}, nil, fmt.Errorf("Unable to parse network from WSL interface: %s, %v", snt, err) + } + } + } + if (ip == nil) || (nt == nil) { + return net.IP{}, nil, fmt.Errorf("Unable to find IP address and network from WSL interface") + } + return ip, nt, nil +} + func installEc2EndpointNetworkAlias() ([]byte, error) { out, err := runAndWrapAdminErrors("netsh", "interface", "ipv4", "add", "address", "Loopback Pseudo-Interface 1", "169.254.169.254", "255.255.0.0") if msgFound(alreadyRegisteredLocalised, string(out)) { diff --git a/server/ec2proxy.go b/server/ec2proxy.go index ca8b92882..e28a380b1 100644 --- a/server/ec2proxy.go +++ b/server/ec2proxy.go @@ -19,8 +19,8 @@ const ( // StartProxy starts a http proxy server that listens on the standard EC2 Instance Metadata endpoint http://169.254.169.254:80/ // and forwards requests through to the running `aws-vault exec` command -func StartProxy() error { - var localServerURL, err = url.Parse(fmt.Sprintf("http://%s/", ec2CredentialsServerAddr)) +func StartProxy(serverAddress string) error { + var localServerURL, err = url.Parse(fmt.Sprintf("http://%s/", serverAddress)) if err != nil { return err } diff --git a/server/ec2server.go b/server/ec2server.go index 8ef08db97..9f02d15e2 100644 --- a/server/ec2server.go +++ b/server/ec2server.go @@ -13,23 +13,71 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" ) -const ec2CredentialsServerAddr = "127.0.0.1:9099" +const DefaultEc2CredentialsServerIp = "127.0.0.1" +const DefaultEc2CredentialsServerPort = "9099" +const DefaultEc2CredentialsServerAddr = DefaultEc2CredentialsServerIp + ":" + DefaultEc2CredentialsServerPort + +type Ec2ServerParameters struct { + region string + serverAddress string + allowedNetworks []net.IPNet +} + +type Ec2ServerParameter interface { + apply(*Ec2ServerParameters) +} + +type ec2ServerAddress struct { + serverAddress string +} + +func (p *ec2ServerAddress) apply(params *Ec2ServerParameters) { + params.serverAddress = p.serverAddress +} + +func WithEc2ServerAddress(addr string) Ec2ServerParameter { + return &ec2ServerAddress{serverAddress: addr} +} + +type ec2ServerAllowedAddress struct { + net net.IPNet +} + +func (p *ec2ServerAllowedAddress) apply(params *Ec2ServerParameters) { + params.allowedNetworks = append(params.allowedNetworks, p.net) +} + +func WithEc2ServerAllowedNetwork(net net.IPNet) Ec2ServerParameter { + return &ec2ServerAllowedAddress{net: net} +} + +func NewEc2ServerParameters(region string, params ...Ec2ServerParameter) *Ec2ServerParameters { + result := &Ec2ServerParameters{ + region: region, + serverAddress: DefaultEc2CredentialsServerAddr, + allowedNetworks: make([]net.IPNet, 0), + } + for _, p := range params { + p.apply(result) + } + return result +} // StartEc2CredentialsServer starts a EC2 Instance Metadata server and endpoint proxy -func StartEc2CredentialsServer(ctx context.Context, credsProvider aws.CredentialsProvider, region string) error { +func StartEc2CredentialsServer(ctx context.Context, credsProvider aws.CredentialsProvider, params *Ec2ServerParameters) error { credsCache := aws.NewCredentialsCache(credsProvider) // pre-fetch credentials so that we can respond quickly to the first request // SDKs seem to very aggressively timeout _, _ = credsCache.Retrieve(ctx) - go startEc2CredentialsServer(credsCache, region) + go startEc2CredentialsServer(credsCache, params) return nil } -func startEc2CredentialsServer(credsProvider aws.CredentialsProvider, region string) { - log.Printf("Starting EC2 Instance Metadata server on %s", ec2CredentialsServerAddr) +func startEc2CredentialsServer(credsProvider aws.CredentialsProvider, params *Ec2ServerParameters) { + log.Printf("Starting EC2 Instance Metadata server on %s", params.serverAddress) router := http.NewServeMux() router.HandleFunc("/latest/meta-data/iam/security-credentials/", func(w http.ResponseWriter, r *http.Request) { @@ -48,40 +96,59 @@ func startEc2CredentialsServer(credsProvider aws.CredentialsProvider, region str // used by AWS SDK to determine region router.HandleFunc("/latest/meta-data/dynamic/instance-identity/document", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, `{"region": "`+region+`"}`) + fmt.Fprintf(w, `{"region": "`+params.region+`"}`) }) router.HandleFunc("/latest/meta-data/iam/security-credentials/local-credentials", credsHandler(credsProvider)) - log.Fatalln(http.ListenAndServe(ec2CredentialsServerAddr, withLogging(withSecurityChecks(router)))) + log.Fatalln(http.ListenAndServe(params.serverAddress, withLogging(&withSecurityChecks{params, router}))) +} + +type withSecurityChecks struct { + *Ec2ServerParameters + next *http.ServeMux } // withSecurityChecks is middleware to protect the server from attack vectors -func withSecurityChecks(next *http.ServeMux) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Check the remote ip is from the loopback, otherwise clients on the same network segment could - // potentially route traffic via 169.254.169.254:80 - // See https://developer.apple.com/library/content/qa/qa1357/_index.html - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !net.ParseIP(ip).IsLoopback() { - http.Error(w, "Access denied from non-localhost address", http.StatusUnauthorized) - return +func (sc *withSecurityChecks) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Check the remote ip is from the loopback, otherwise clients on the same network segment could + // potentially route traffic via 169.254.169.254:80 + // See https://developer.apple.com/library/content/qa/qa1357/_index.html + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + checkIp := func() bool { + remoteIp := net.ParseIP(ip) + if remoteIp == nil { + return false } - // Check that the request is to 169.254.169.254 - // Without this it's possible for an attacker to mount a DNS rebinding attack - // See https://github.com/99designs/aws-vault/issues/578 - if r.Host != ec2MetadataEndpointIP && r.Host != ec2MetadataEndpointAddr { - http.Error(w, fmt.Sprintf("Access denied for host '%s'", r.Host), http.StatusUnauthorized) - return + for _, allowedNetwork := range sc.allowedNetworks { + if allowedNetwork.Contains(remoteIp) { + return true + } } - next.ServeHTTP(w, r) + return remoteIp.IsLoopback() + } + + if !checkIp() { + http.Error(w, "Access denied from not allowed address", http.StatusUnauthorized) + return + } + + // Check that the request is to 169.254.169.254 + // Without this it's possible for an attacker to mount a DNS rebinding attack + // See https://github.com/99designs/aws-vault/issues/578 + if r.Host != ec2MetadataEndpointIP && r.Host != ec2MetadataEndpointAddr { + http.Error(w, fmt.Sprintf("Access denied for host '%s'", r.Host), http.StatusUnauthorized) + return } + + sc.next.ServeHTTP(w, r) } func credsHandler(credsProvider aws.CredentialsProvider) http.HandlerFunc {