From 16ba0fd8a6ff7392feb4e0d210d55aa58361bb6e Mon Sep 17 00:00:00 2001 From: Ali Nehzat Date: Thu, 9 Nov 2023 11:47:04 +1100 Subject: [PATCH] return url params in options handler --- cmd/happy/main.go | 17 ++++++++++------- testdata/main_api.go | 19 ++++++++++++------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/cmd/happy/main.go b/cmd/happy/main.go index d0cc517..cee7722 100644 --- a/cmd/happy/main.go +++ b/cmd/happy/main.go @@ -163,7 +163,10 @@ func generateHandler(gctx *genContext, eps []endpoint, tree *tree) error { w.L("return nil") return } - w.L("return %#v", ep.directive.options) + w.L("out := %#v", ep.directive.options) + w.L("// Merge url params into options") + w.L("for k, v := range params { out[k] = v }") + w.L("return out") }) w.L("return nil") }) @@ -498,7 +501,7 @@ func genEndpoint(gctx *genContext, w *codewriter.Writer, ep endpoint) error { case implements(param, textUnmarshalerInterface()): paramName := fmt.Sprintf("param%d", index) w.L("var %s %s", paramName, ref) - w.L("if err := %s.UnmarshalText([]byte(params[%d])); err != nil {", paramName, index) + w.L(`if err := %s.UnmarshalText([]byte(params[":%s"])); err != nil {`, paramName, param.Name()) w.L(" http.Error(w, \"%s: \" + err.Error(), http.StatusBadRequest)", param.Name()) w.L(" return") w.L("}") @@ -506,15 +509,15 @@ func genEndpoint(gctx *genContext, w *codewriter.Writer, ep endpoint) error { case bt == "string": if bt != ref { - args = append(args, fmt.Sprintf("%s(params[%d])", ref, index)) + args = append(args, fmt.Sprintf(`%s(params[":%s"])`, ref, param.Name())) } else { - args = append(args, fmt.Sprintf("params[%d]", index)) + args = append(args, fmt.Sprintf(`params[":%s"]`, param.Name())) } case bt == "int": paramName := fmt.Sprintf("param%d", index) w.L("var %s int", paramName) - w.L("%s, err = strconv.Atoi(params[%d])", paramName, index) + w.L(`%s, err = strconv.Atoi(params[":%s"])`, paramName, param.Name()) if bt != ref { args = append(args, fmt.Sprintf("%s(%s)", ref, paramName)) } else { @@ -618,7 +621,7 @@ func (t *tree) Write(w *codewriter.Writer, earlyExit string, visitor func(w *cod w.L(` parts = append(parts, p)`) w.L(` }`) w.L(`}`) - w.L(`var params []string`) + w.L(`var params map[string]string = map[string]string{}`) w.L(`_ = params`) w.L(`switch parts[0] {`) t.recursiveWrite(w, 0, visitor, earlyExit) @@ -632,7 +635,7 @@ func (t *tree) recursiveWrite(w *codewriter.Writer, n int, visitor func(w *codew w.L(`case "%s":`, t.part) } else { w.L(`default: // Parameter %s`, t.part) - w.L(` params = append(params, parts[%d])`, n) + w.L(` params["%s"] = parts[%d]`, t.part, n) } w.In(func(w *codewriter.Writer) { w.In(func(w *codewriter.Writer) { diff --git a/testdata/main_api.go b/testdata/main_api.go index f60ee98..3853c2a 100644 --- a/testdata/main_api.go +++ b/testdata/main_api.go @@ -18,7 +18,7 @@ func (h *Service) HandlerOptions(r *http.Request) map[string]string { parts = append(parts, p) } } - var params []string + var params map[string]string = map[string]string{} _ = params switch parts[0] { case "": @@ -30,7 +30,12 @@ func (h *Service) HandlerOptions(r *http.Request) map[string]string { if len(parts) == 2 { switch r.Method { // Leaf case "POST": - return map[string]string{"authenticated": ""} + out := map[string]string{"authenticated": ""} + // Merge url params into options + for k, v := range params { + out[k] = v + } + return out } return nil } @@ -49,7 +54,7 @@ func (h *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { parts = append(parts, p) } } - var params []string + var params map[string]string = map[string]string{} _ = params switch parts[0] { case "": @@ -81,12 +86,12 @@ func (h *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } switch parts[2] { default: // Parameter :id - params = append(params, parts[2]) + params[":id"] = parts[2] if len(parts) == 3 { switch r.Method { // Leaf case "GET": var param0 ID - if err := param0.UnmarshalText([]byte(params[0])); err != nil { + if err := param0.UnmarshalText([]byte(params[":id"])); err != nil { http.Error(w, "id: "+err.Error(), http.StatusBadRequest) return } @@ -101,7 +106,7 @@ func (h *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { // Leaf case "GET": var param0 ID - if err := param0.UnmarshalText([]byte(params[0])); err != nil { + if err := param0.UnmarshalText([]byte(params[":id"])); err != nil { http.Error(w, "id: "+err.Error(), http.StatusBadRequest) return } @@ -170,7 +175,7 @@ matched: http.Error(w, `failed to encode response: `+err.Error(), http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.WriteHeader(http.StatusOK) w.Write(data)