From 7daf3aa7d479e48efe8e3854b01cea47e603528d Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Mon, 13 Jan 2020 21:31:11 +1100 Subject: [PATCH] Don't use exec syscall when starting the server --- cli/exec.go | 63 ++++++++++++++++++++++++++++++++++++++++++++- cli/exec_default.go | 42 ------------------------------ cli/exec_unix.go | 21 --------------- server/server.go | 11 +++----- 4 files changed, 66 insertions(+), 71 deletions(-) delete mode 100644 cli/exec_default.go delete mode 100644 cli/exec_unix.go diff --git a/cli/exec.go b/cli/exec.go index 70e5f81e7..10e7d4c75 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -5,7 +5,11 @@ import ( "fmt" "log" "os" + "os/exec" + "os/signal" + "runtime" "strings" + "syscall" "time" "github.com/99designs/aws-vault/server" @@ -179,7 +183,12 @@ func ExecCommand(input ExecCommandInput) error { } } - err = exec(input.Command, input.Args, env) + if input.StartServer { + err = execCmd(input.Command, input.Args, env) + } else { + err = execSyscall(input.Command, input.Args, env) + } + if err != nil { return fmt.Errorf("Error execing process: %w", err) } @@ -207,3 +216,55 @@ func (e *environ) Set(key, val string) { e.Unset(key) *e = append(*e, key+"="+val) } + +func execCmd(command string, args []string, env []string) error { + cmd := exec.Command(command, args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = env + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("Failed to start command: %v", err) + } + + go func() { + for { + sig := <-sigChan + cmd.Process.Signal(sig) + } + }() + + if err := cmd.Wait(); err != nil { + cmd.Process.Signal(os.Kill) + return fmt.Errorf("Failed to wait for command termination: %v", err) + } + + waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus) + os.Exit(waitStatus.ExitStatus()) + return nil +} + +func supportsExecSyscall() bool { + return runtime.GOOS == "linux" || runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" +} + +func execSyscall(command string, args []string, env []string) error { + if !supportsExecSyscall() { + return execCmd(command, args, env) + } + + argv0, err := exec.LookPath(command) + if err != nil { + return err + } + + argv := make([]string, 0, 1+len(args)) + argv = append(argv, command) + argv = append(argv, args...) + + return syscall.Exec(argv0, argv, env) +} diff --git a/cli/exec_default.go b/cli/exec_default.go deleted file mode 100644 index 47b2a629e..000000000 --- a/cli/exec_default.go +++ /dev/null @@ -1,42 +0,0 @@ -// +build !linux,!darwin,!freebsd - -package cli - -import ( - "fmt" - "os" - osexec "os/exec" - "os/signal" - "syscall" -) - -func exec(command string, args []string, env []string) error { - cmd := osexec.Command(command, args...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Env = env - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan) - - if err := cmd.Start(); err != nil { - return fmt.Errorf("Failed to start command: %v", err) - } - - go func() { - for { - sig := <-sigChan - cmd.Process.Signal(sig) - } - }() - - if err := cmd.Wait(); err != nil { - cmd.Process.Signal(os.Kill) - return fmt.Errorf("Failed to wait for command termination: %v", err) - } - - waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus) - os.Exit(waitStatus.ExitStatus()) - return nil -} diff --git a/cli/exec_unix.go b/cli/exec_unix.go deleted file mode 100644 index 9b70fa452..000000000 --- a/cli/exec_unix.go +++ /dev/null @@ -1,21 +0,0 @@ -// +build linux darwin freebsd - -package cli - -import ( - osexec "os/exec" - "syscall" -) - -func exec(command string, args []string, env []string) error { - argv0, err := osexec.LookPath(command) - if err != nil { - return err - } - - argv := make([]string, 0, 1+len(args)) - argv = append(argv, command) - argv = append(argv, args...) - - return syscall.Exec(argv0, argv, env) -} diff --git a/server/server.go b/server/server.go index 5b274cc0d..40d23416d 100644 --- a/server/server.go +++ b/server/server.go @@ -85,13 +85,10 @@ func StartCredentialsServer(creds *credentials.Credentials) error { } } - l, err := net.Listen("tcp", localServerBind) - if err != nil { - return err - } - - log.Printf("Local instance role server running on %s", l.Addr()) - go http.Serve(l, credsHandler(creds)) + log.Printf("Starting local instance role server on %s", localServerBind) + go func() { + log.Fatalln(http.ListenAndServe(localServerBind, credsHandler(creds))) + }() return nil }