From 591a63b08a19fea11184c54818210ec4a9bd92dc Mon Sep 17 00:00:00 2001 From: Michael Tibben Date: Mon, 20 Apr 2020 14:47:29 +1000 Subject: [PATCH] Fix error handling in ecs server --- server/ec2.go | 2 +- server/ecs.go | 11 ++++++----- server/httplog.go | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/server/ec2.go b/server/ec2.go index 50a8e47d8..737c1c588 100644 --- a/server/ec2.go +++ b/server/ec2.go @@ -86,7 +86,7 @@ func startEc2CredentialsServer(creds *credentials.Credentials, region string) { router.HandleFunc("/latest/meta-data/iam/security-credentials/local-credentials", credsHandler(creds)) - log.Fatalln(http.ListenAndServe(ec2CredentialsServerAddr, logRequest(withLoopbackSecurityCheck(router)))) + log.Fatalln(http.ListenAndServe(ec2CredentialsServerAddr, withLogging(withLoopbackSecurityCheck(router)))) } // withLoopbackSecurityCheck is middleware to check that the request comes from the loopback device diff --git a/server/ecs.go b/server/ecs.go index e915c6e8e..46fcc1dd8 100644 --- a/server/ecs.go +++ b/server/ecs.go @@ -12,10 +12,11 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" ) -func writeErrorMessage(w http.ResponseWriter, msg string, status int) { - err := json.NewEncoder(w).Encode(map[string]string{"Message": msg}) - if err != nil { - http.Error(w, err.Error(), status) +func writeErrorMessage(w http.ResponseWriter, msg string, statusCode int) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(map[string]string{"Message": msg}); err != nil { + log.Println(err.Error()) } } @@ -41,7 +42,7 @@ func StartEcsCredentialServer(creds *credentials.Credentials) (string, string, e } go func() { - err := http.Serve(listener, logRequest(withAuthorizationCheck(token, ecsCredsHandler(creds)))) + err := http.Serve(listener, withLogging(withAuthorizationCheck(token, ecsCredsHandler(creds)))) // returns ErrServerClosed on graceful close if err != http.ErrServerClosed { log.Fatalf("ecs server: %s", err.Error()) diff --git a/server/httplog.go b/server/httplog.go index 1092b1848..35b4c8209 100644 --- a/server/httplog.go +++ b/server/httplog.go @@ -16,7 +16,7 @@ func (w *loggingMiddlewareResponseWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) } -func logRequest(handler http.Handler) http.Handler { +func withLogging(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestStart := time.Now() w2 := &loggingMiddlewareResponseWriter{w, http.StatusOK}