diff --git a/tools/goctl/api/gogen/genmiddleware.go b/tools/goctl/api/gogen/genmiddleware.go new file mode 100644 index 00000000..642a039c --- /dev/null +++ b/tools/goctl/api/gogen/genmiddleware.go @@ -0,0 +1,59 @@ +package gogen + +import ( + "bytes" + "strings" + "text/template" + + "github.com/tal-tech/go-zero/tools/goctl/api/util" +) + +var middlewareImplementCode = ` +package middleware + +import "net/http" + +type {{.name}} struct { +} + +func New{{.name}}() *{{.name}} { + return &{{.name}}{} +} + +func (m *{{.name}})Handle(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // TODO generate middleware implement function, delete after code implementation + + // Passthrough to next handler if need + next(w, r) + } +} +` + +func genMiddleware(dir string, middlewares []string) error { + for _, item := range middlewares { + filename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "middleware" + ".go" + fp, created, err := util.MaybeCreateFile(dir, middlewareDir, filename) + if err != nil { + return err + } + if !created { + return nil + } + defer fp.Close() + + name := strings.TrimSuffix(item, "Middleware") + "Middleware" + t := template.Must(template.New("contextTemplate").Parse(middlewareImplementCode)) + buffer := new(bytes.Buffer) + err = t.Execute(buffer, map[string]string{ + "name": strings.Title(name), + }) + if err != nil { + return nil + } + formatCode := formatCode(buffer.String()) + _, err = fp.WriteString(formatCode) + return err + } + return nil +} diff --git a/tools/goctl/api/gogen/gensvc.go b/tools/goctl/api/gogen/gensvc.go index b7b83882..7e7176e9 100644 --- a/tools/goctl/api/gogen/gensvc.go +++ b/tools/goctl/api/gogen/gensvc.go @@ -3,6 +3,7 @@ package gogen import ( "bytes" "fmt" + "strings" "text/template" "github.com/tal-tech/go-zero/tools/goctl/api/spec" @@ -31,14 +32,6 @@ func NewServiceContext(c {{.config}}) *ServiceContext { } } -{{.middlewareImplement}} -` - middlewareImplementCode = `func %s(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // TODO generate middleware implement function, delete after code implementation - } -} - ` ) @@ -70,16 +63,21 @@ func genServiceContext(dir string, api *spec.ApiSpec) error { var middlewareStr string var middlewareAssignment string - var middlewareImplement string - for _, item := range getMiddleware(api) { + var middlewares = getMiddleware(api) + err = genMiddleware(dir, middlewares) + if err != nil { + return err + } + + for _, item := range middlewares { middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item) - middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, item) - middlewareImplement += fmt.Sprintf(middlewareImplementCode, item) + name := strings.TrimSuffix(item, "Middleware") + "Middleware" + middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle")) } var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" if len(middlewareStr) > 0 { - configImport += "\n\t\"net/http\"" + configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\"" configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl) } @@ -90,7 +88,6 @@ func genServiceContext(dir string, api *spec.ApiSpec) error { "config": "config.Config", "middleware": middlewareStr, "middlewareAssignment": middlewareAssignment, - "middlewareImplement": middlewareImplement, }) if err != nil { return nil diff --git a/tools/goctl/api/gogen/vars.go b/tools/goctl/api/gogen/vars.go index 6a43fc83..f7f94cf7 100644 --- a/tools/goctl/api/gogen/vars.go +++ b/tools/goctl/api/gogen/vars.go @@ -7,6 +7,7 @@ const ( contextDir = interval + "svc" handlerDir = interval + "handler" logicDir = interval + "logic" + middlewareDir = interval + "middleware" typesDir = interval + typesPacket groupProperty = "group" )