diff --git a/cli/exec.go b/cli/exec.go index 47f7f4c54..c966b6c70 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -27,6 +27,8 @@ type ExecCommandInput struct { Args []string StartEc2Server bool StartEcsServer bool + DontStartProxy bool + IsWsl bool Lazy bool JSONDeprecated bool Config vault.ProfileConfig @@ -98,6 +100,15 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { cmd.Flag("ec2-server", "Run a EC2 metadata server in the background for credentials"). BoolVar(&input.StartEc2Server) + cmd.Flag("no-proxy", "If set will not start local proxy"). + BoolVar(&input.DontStartProxy) + + //goland:noinspection GoBoolExpressions + if runtime.GOOS == "windows" { + cmd.Flag("wsl", "If set will bind to wsl interface and allow requests from wsl"). + BoolVar(&input.IsWsl) + } + cmd.Flag("ecs-server", "Run a ECS credential server in the background for credentials (the SDK or app must support AWS_CONTAINER_CREDENTIALS_FULL_URI)"). BoolVar(&input.StartEcsServer) @@ -186,17 +197,33 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke cmdEnv := createEnv(input.ProfileName, config.Region) if input.StartEc2Server { - if server.IsProxyRunning() { - return 0, fmt.Errorf("Another process is already bound to 169.254.169.254:80") + if !input.DontStartProxy { + if server.IsProxyRunning() { + return 0, fmt.Errorf("Another process is already bound to 169.254.169.254:80") + } + + printHelpMessage("Warning: Starting a local EC2 credential server on 169.254.169.254:80; AWS credentials will be accessible to any process while it is running", input.ShowHelpMessages) + if err := server.StartEc2EndpointProxyServerProcess(); err != nil { + return 0, err + } + defer server.StopProxy() } - printHelpMessage("Warning: Starting a local EC2 credential server on 169.254.169.254:80; AWS credentials will be accessible to any process while it is running", input.ShowHelpMessages) - if err := server.StartEc2EndpointProxyServerProcess(); err != nil { - return 0, err + extraParams := make([]server.Ec2ServerParameter, 0) + if input.IsWsl { + ip, net, err := server.GetWslAddressAndNetwork() + if err != nil { + return 1, err + } + extraParams = append(extraParams, + server.WithEc2ServerAddress(ip.String()+":"+server.DefaultEc2CredentialsServerPort), + server.WithEc2ServerAllowedNetwork(*net), + ) } - defer server.StopProxy() - if err = server.StartEc2CredentialsServer(context.TODO(), credsProvider, config.Region); err != nil { + serverParams := server.NewEc2ServerParameters(config.Region, extraParams...) + + if err = server.StartEc2CredentialsServer(context.TODO(), credsProvider, serverParams); err != nil { return 0, fmt.Errorf("Failed to start credential server: %w", err) } printHelpMessage(subshellHelp, input.ShowHelpMessages) diff --git a/cli/proxy.go b/cli/proxy.go index 1ab260614..92ee494d4 100644 --- a/cli/proxy.go +++ b/cli/proxy.go @@ -1,8 +1,13 @@ package cli import ( + "fmt" + "net" "os" + "os/exec" "os/signal" + "runtime" + "strings" "syscall" "github.com/99designs/aws-vault/v7/server" @@ -11,7 +16,8 @@ import ( func ConfigureProxyCommand(app *kingpin.Application) { stop := false - + isWsl := false + serverAddress := "" cmd := app.Command("proxy", "Start a proxy for the ec2 instance role server locally."). Alias("server"). Hidden() @@ -19,16 +25,47 @@ func ConfigureProxyCommand(app *kingpin.Application) { cmd.Flag("stop", "Stop the proxy"). BoolVar(&stop) + cmd.Flag("credentials-server-address", "Server address"). + Default(server.DefaultEc2CredentialsServerAddr). + Hidden(). + StringVar(&serverAddress) + + //goland:noinspection GoBoolExpressions + if runtime.GOOS == "linux" { + cmd.Flag("wsl", "Proxy to credentials server running on Windows host"). + BoolVar(&isWsl) + } + cmd.Action(func(*kingpin.ParseContext) error { if stop { server.StopProxy() return nil } handleSigTerm() - return server.StartProxy() + if (serverAddress == server.DefaultEc2CredentialsServerAddr) && isWsl { + ip, err := getWslHost() + if err != nil { + return err + } + serverAddress = ip.String() + ":" + server.DefaultEc2CredentialsServerPort + } + return server.StartProxy(serverAddress) }) } +func getWslHost() (net.IP, error) { + out, err := exec.Command("ip", "route").CombinedOutput() + if err != nil { + return net.IP{}, err + } + for _, line := range strings.Split(string(out), "\n") { + if strings.Contains(line, "default") { + return net.ParseIP(strings.Split(line, " ")[2]), nil + } + } + return nil, fmt.Errorf("unable to find default gateway") +} + func handleSigTerm() { // shutdown c := make(chan os.Signal, 1) diff --git a/server/ec2alias_bsd.go b/server/ec2alias_bsd.go index 2ca198241..3f472cb4d 100644 --- a/server/ec2alias_bsd.go +++ b/server/ec2alias_bsd.go @@ -5,6 +5,10 @@ package server import "os/exec" +func GetWslAddressAndNetwork() (net.IP, *net.IPNet, error) { + return net.IP{}, net.IPNet{}, fmt.Errorf("WSL is a Windows only feature") +} + func installEc2EndpointNetworkAlias() ([]byte, error) { return exec.Command("ifconfig", "lo0", "alias", "169.254.169.254").CombinedOutput() } diff --git a/server/ec2alias_linux.go b/server/ec2alias_linux.go index a137c292a..f5e014b3c 100644 --- a/server/ec2alias_linux.go +++ b/server/ec2alias_linux.go @@ -3,7 +3,15 @@ package server -import "os/exec" +import ( + "fmt" + "net" + "os/exec" +) + +func GetWslAddressAndNetwork() (net.IP, *net.IPNet, error) { + return net.IP{}, net.IPNet{}, fmt.Errorf("WSL is a Windows only feature") +} func installEc2EndpointNetworkAlias() ([]byte, error) { return exec.Command("ip", "addr", "add", "169.254.169.254/24", "dev", "lo", "label", "lo:0").CombinedOutput() diff --git a/server/ec2alias_windows.go b/server/ec2alias_windows.go index adf948f75..34e86c70e 100644 --- a/server/ec2alias_windows.go +++ b/server/ec2alias_windows.go @@ -5,6 +5,7 @@ package server import ( "fmt" + "net" "os/exec" "strings" ) @@ -41,6 +42,35 @@ func runAndWrapAdminErrors(name string, arg ...string) ([]byte, error) { return out, err } +func GetWslAddressAndNetwork() (net.IP, *net.IPNet, error) { + out, err := runAndWrapAdminErrors("netsh", "interface", "ipv4", "show", "addresses", "vEthernet (WSL)") + if err != nil { + return net.IP{}, nil, err + } + ip := net.IP{} + nt := &net.IPNet{} + + lines := strings.Split(string(out), "\n") + for _, line := range lines { + if strings.Contains(line, "IP Address:") { + sip := strings.Trim(strings.Split(line, ":")[1], " \n\r\t") + if ip = net.ParseIP(sip); ip == nil { + return net.IP{}, nil, fmt.Errorf("Unable to parse IP address from WSL interface: %s", sip) + } + } + if strings.Contains(line, "Subnet Prefix:") { + snt := strings.Split(strings.Trim(strings.Split(line, ":")[1], " \n\r\t"), " ")[0] + if _, nt, err = net.ParseCIDR(snt); err != nil { + return net.IP{}, nil, fmt.Errorf("Unable to parse network from WSL interface: %s, %v", snt, err) + } + } + } + if (ip == nil) || (nt == nil) { + return net.IP{}, nil, fmt.Errorf("Unable to find IP address and network from WSL interface") + } + return ip, nt, nil +} + func installEc2EndpointNetworkAlias() ([]byte, error) { out, err := runAndWrapAdminErrors("netsh", "interface", "ipv4", "add", "address", "Loopback Pseudo-Interface 1", "169.254.169.254", "255.255.0.0") if msgFound(alreadyRegisteredLocalised, string(out)) { diff --git a/server/ec2proxy.go b/server/ec2proxy.go index ca8b92882..e28a380b1 100644 --- a/server/ec2proxy.go +++ b/server/ec2proxy.go @@ -19,8 +19,8 @@ const ( // StartProxy starts a http proxy server that listens on the standard EC2 Instance Metadata endpoint http://169.254.169.254:80/ // and forwards requests through to the running `aws-vault exec` command -func StartProxy() error { - var localServerURL, err = url.Parse(fmt.Sprintf("http://%s/", ec2CredentialsServerAddr)) +func StartProxy(serverAddress string) error { + var localServerURL, err = url.Parse(fmt.Sprintf("http://%s/", serverAddress)) if err != nil { return err } diff --git a/server/ec2server.go b/server/ec2server.go index 8ef08db97..9f02d15e2 100644 --- a/server/ec2server.go +++ b/server/ec2server.go @@ -13,23 +13,71 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" ) -const ec2CredentialsServerAddr = "127.0.0.1:9099" +const DefaultEc2CredentialsServerIp = "127.0.0.1" +const DefaultEc2CredentialsServerPort = "9099" +const DefaultEc2CredentialsServerAddr = DefaultEc2CredentialsServerIp + ":" + DefaultEc2CredentialsServerPort + +type Ec2ServerParameters struct { + region string + serverAddress string + allowedNetworks []net.IPNet +} + +type Ec2ServerParameter interface { + apply(*Ec2ServerParameters) +} + +type ec2ServerAddress struct { + serverAddress string +} + +func (p *ec2ServerAddress) apply(params *Ec2ServerParameters) { + params.serverAddress = p.serverAddress +} + +func WithEc2ServerAddress(addr string) Ec2ServerParameter { + return &ec2ServerAddress{serverAddress: addr} +} + +type ec2ServerAllowedAddress struct { + net net.IPNet +} + +func (p *ec2ServerAllowedAddress) apply(params *Ec2ServerParameters) { + params.allowedNetworks = append(params.allowedNetworks, p.net) +} + +func WithEc2ServerAllowedNetwork(net net.IPNet) Ec2ServerParameter { + return &ec2ServerAllowedAddress{net: net} +} + +func NewEc2ServerParameters(region string, params ...Ec2ServerParameter) *Ec2ServerParameters { + result := &Ec2ServerParameters{ + region: region, + serverAddress: DefaultEc2CredentialsServerAddr, + allowedNetworks: make([]net.IPNet, 0), + } + for _, p := range params { + p.apply(result) + } + return result +} // StartEc2CredentialsServer starts a EC2 Instance Metadata server and endpoint proxy -func StartEc2CredentialsServer(ctx context.Context, credsProvider aws.CredentialsProvider, region string) error { +func StartEc2CredentialsServer(ctx context.Context, credsProvider aws.CredentialsProvider, params *Ec2ServerParameters) error { credsCache := aws.NewCredentialsCache(credsProvider) // pre-fetch credentials so that we can respond quickly to the first request // SDKs seem to very aggressively timeout _, _ = credsCache.Retrieve(ctx) - go startEc2CredentialsServer(credsCache, region) + go startEc2CredentialsServer(credsCache, params) return nil } -func startEc2CredentialsServer(credsProvider aws.CredentialsProvider, region string) { - log.Printf("Starting EC2 Instance Metadata server on %s", ec2CredentialsServerAddr) +func startEc2CredentialsServer(credsProvider aws.CredentialsProvider, params *Ec2ServerParameters) { + log.Printf("Starting EC2 Instance Metadata server on %s", params.serverAddress) router := http.NewServeMux() router.HandleFunc("/latest/meta-data/iam/security-credentials/", func(w http.ResponseWriter, r *http.Request) { @@ -48,40 +96,59 @@ func startEc2CredentialsServer(credsProvider aws.CredentialsProvider, region str // used by AWS SDK to determine region router.HandleFunc("/latest/meta-data/dynamic/instance-identity/document", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, `{"region": "`+region+`"}`) + fmt.Fprintf(w, `{"region": "`+params.region+`"}`) }) router.HandleFunc("/latest/meta-data/iam/security-credentials/local-credentials", credsHandler(credsProvider)) - log.Fatalln(http.ListenAndServe(ec2CredentialsServerAddr, withLogging(withSecurityChecks(router)))) + log.Fatalln(http.ListenAndServe(params.serverAddress, withLogging(&withSecurityChecks{params, router}))) +} + +type withSecurityChecks struct { + *Ec2ServerParameters + next *http.ServeMux } // withSecurityChecks is middleware to protect the server from attack vectors -func withSecurityChecks(next *http.ServeMux) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Check the remote ip is from the loopback, otherwise clients on the same network segment could - // potentially route traffic via 169.254.169.254:80 - // See https://developer.apple.com/library/content/qa/qa1357/_index.html - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !net.ParseIP(ip).IsLoopback() { - http.Error(w, "Access denied from non-localhost address", http.StatusUnauthorized) - return +func (sc *withSecurityChecks) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Check the remote ip is from the loopback, otherwise clients on the same network segment could + // potentially route traffic via 169.254.169.254:80 + // See https://developer.apple.com/library/content/qa/qa1357/_index.html + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + checkIp := func() bool { + remoteIp := net.ParseIP(ip) + if remoteIp == nil { + return false } - // Check that the request is to 169.254.169.254 - // Without this it's possible for an attacker to mount a DNS rebinding attack - // See https://github.com/99designs/aws-vault/issues/578 - if r.Host != ec2MetadataEndpointIP && r.Host != ec2MetadataEndpointAddr { - http.Error(w, fmt.Sprintf("Access denied for host '%s'", r.Host), http.StatusUnauthorized) - return + for _, allowedNetwork := range sc.allowedNetworks { + if allowedNetwork.Contains(remoteIp) { + return true + } } - next.ServeHTTP(w, r) + return remoteIp.IsLoopback() + } + + if !checkIp() { + http.Error(w, "Access denied from not allowed address", http.StatusUnauthorized) + return + } + + // Check that the request is to 169.254.169.254 + // Without this it's possible for an attacker to mount a DNS rebinding attack + // See https://github.com/99designs/aws-vault/issues/578 + if r.Host != ec2MetadataEndpointIP && r.Host != ec2MetadataEndpointAddr { + http.Error(w, fmt.Sprintf("Access denied for host '%s'", r.Host), http.StatusUnauthorized) + return } + + sc.next.ServeHTTP(w, r) } func credsHandler(credsProvider aws.CredentialsProvider) http.HandlerFunc {