diff --git a/tools/goctl/api/gogen/genconfig.go b/tools/goctl/api/gogen/genconfig.go index a8e634cd..8410cfb9 100644 --- a/tools/goctl/api/gogen/genconfig.go +++ b/tools/goctl/api/gogen/genconfig.go @@ -8,6 +8,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/vars" ) @@ -47,7 +48,12 @@ func genConfig(dir string, api *spec.ApiSpec) error { } var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) - t := template.Must(template.New("configTemplate").Parse(configTemplate)) + text, err := templatex.LoadTemplate(category, configTemplateFile, configTemplate) + if err != nil { + return err + } + + t := template.Must(template.New("configTemplate").Parse(text)) buffer := new(bytes.Buffer) err = t.Execute(buffer, map[string]string{ "authImport": authImportStr, diff --git a/tools/goctl/api/gogen/genetc.go b/tools/goctl/api/gogen/genetc.go index 55082b3e..20dde425 100644 --- a/tools/goctl/api/gogen/genetc.go +++ b/tools/goctl/api/gogen/genetc.go @@ -8,6 +8,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) const ( @@ -39,7 +40,12 @@ func genEtc(dir string, api *spec.ApiSpec) error { port = strconv.Itoa(defaultPort) } - t := template.Must(template.New("etcTemplate").Parse(etcTemplate)) + text, err := templatex.LoadTemplate(category, etcTemplateFile, etcTemplate) + if err != nil { + return err + } + + t := template.Must(template.New("etcTemplate").Parse(text)) buffer := new(bytes.Buffer) err = t.Execute(buffer, map[string]string{ "serviceName": service.Name, diff --git a/tools/goctl/api/gogen/genhandlers.go b/tools/goctl/api/gogen/genhandlers.go index b30d04ce..31d02151 100644 --- a/tools/goctl/api/gogen/genhandlers.go +++ b/tools/goctl/api/gogen/genhandlers.go @@ -9,115 +9,76 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/spec" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/vars" ) -const ( - handlerTemplate = `package handler +const handlerTemplate = `package handler import ( "net/http" - {{.importPackages}} + {{.ImportPackages}} ) -func {{.handlerName}}(ctx *svc.ServiceContext) http.HandlerFunc { +func {{.HandlerName}}(ctx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - {{.handlerBody}} - } -} -` - handlerBodyTemplate = `{{.parseRequest}} - {{.processBody}} -` - parseRequestTemplate = `var req {{.requestType}} + var req types.{{.RequestType}} if err := httpx.Parse(r, &req); err != nil { httpx.Error(w, err) return } -` - hasRespTemplate = ` - l := logic.{{.logic}}(r.Context(), ctx) - {{.logicResponse}} l.{{.callee}}({{.req}}) + + l := logic.New{{.LogicType}}(r.Context(), ctx) + {{if .HasResp}}resp, {{end}}err := l.{{.Call}}(req) if err != nil { httpx.Error(w, err) } else { - {{.respWriter}} + {{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}} } - ` -) + } +} +` + +type Handler struct { + ImportPackages string + HandlerName string + RequestType string + LogicType string + Call string + HasResp bool +} func genHandler(dir string, group spec.Group, route spec.Route) error { handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler") if !ok { return fmt.Errorf("missing handler annotation for %q", route.Path) } - handler = getHandlerName(handler) - var reqBody string - if len(route.RequestType.Name) > 0 { - var bodyBuilder strings.Builder - t := template.Must(template.New("parseRequest").Parse(parseRequestTemplate)) - if err := t.Execute(&bodyBuilder, map[string]string{ - "requestType": typesPacket + "." + util.Title(route.RequestType.Name), - }); err != nil { - return err - } - reqBody = bodyBuilder.String() - } - var req = "req" - if len(route.RequestType.Name) == 0 { - req = "" - } - var logicResponse string - var writeResponse string - var respWriter = `httpx.WriteJson(w, http.StatusOK, resp)` - if len(route.ResponseType.Name) > 0 { - logicResponse = "resp, err :=" - writeResponse = "resp, err" - } else { - logicResponse = "err :=" - writeResponse = "nil, err" - respWriter = `httpx.Ok(w)` + handler = getHandlerName(handler) + if getHandlerFolderPath(group, route) != handlerDir { + handler = strings.Title(handler) } - var logicBodyBuilder strings.Builder - t := template.Must(template.New("hasRespTemplate").Parse(hasRespTemplate)) - if err := t.Execute(&logicBodyBuilder, map[string]string{ - "logic": "New" + strings.TrimSuffix(strings.Title(handler), "Handler") + "Logic", - "callee": strings.Title(strings.TrimSuffix(handler, "Handler")), - "req": req, - "logicResponse": logicResponse, - "writeResponse": writeResponse, - "respWriter": respWriter, - }); err != nil { + parentPkg, err := getParentPackage(dir) + if err != nil { return err } - respBody := logicBodyBuilder.String() - - if !strings.HasSuffix(handler, "Handler") { - handler = handler + "Handler" - } - var bodyBuilder strings.Builder - bodyTemplate := template.Must(template.New("handlerBodyTemplate").Parse(handlerBodyTemplate)) - if err := bodyTemplate.Execute(&bodyBuilder, map[string]string{ - "parseRequest": reqBody, - "processBody": respBody, - }); err != nil { - return err - } - return doGenToFile(dir, handler, group, route, bodyBuilder) + return doGenToFile(dir, handler, group, route, Handler{ + ImportPackages: genHandlerImports(group, route, parentPkg), + HandlerName: handler, + RequestType: util.Title(route.RequestType.Name), + LogicType: strings.TrimSuffix(strings.Title(handler), "Handler") + "Logic", + Call: strings.Title(strings.TrimSuffix(handler, "Handler")), + HasResp: len(route.ResponseType.Name) > 0, + }) } -func doGenToFile(dir, handler string, group spec.Group, route spec.Route, bodyBuilder strings.Builder) error { +func doGenToFile(dir, handler string, group spec.Group, route spec.Route, handleObj Handler) error { if getHandlerFolderPath(group, route) != handlerDir { handler = strings.Title(handler) } - parentPkg, err := getParentPackage(dir) - if err != nil { - return err - } filename := strings.ToLower(handler) if strings.HasSuffix(filename, "handler") { filename = filename + ".go" @@ -132,16 +93,18 @@ func doGenToFile(dir, handler string, group spec.Group, route spec.Route, bodyBu return nil } defer fp.Close() - t := template.Must(template.New("handlerTemplate").Parse(handlerTemplate)) + + text, err := templatex.LoadTemplate(category, handlerTemplateFile, handlerTemplate) + if err != nil { + return err + } + buffer := new(bytes.Buffer) - err = t.Execute(buffer, map[string]string{ - "importPackages": genHandlerImports(group, route, parentPkg), - "handlerName": handler, - "handlerBody": strings.TrimSpace(bodyBuilder.String()), - }) + err = template.Must(template.New("handlerTemplate").Parse(text)).Execute(buffer, handleObj) if err != nil { return nil } + formatCode := formatCode(buffer.String()) _, err = fp.WriteString(formatCode) return err diff --git a/tools/goctl/api/gogen/genlogic.go b/tools/goctl/api/gogen/genlogic.go index ff569de8..e27d3fd5 100644 --- a/tools/goctl/api/gogen/genlogic.go +++ b/tools/goctl/api/gogen/genlogic.go @@ -9,6 +9,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/vars" ) @@ -93,7 +94,12 @@ func genLogicByRoute(dir string, group spec.Group, route spec.Route) error { requestString = "req " + "types." + strings.Title(route.RequestType.Name) } - t := template.Must(template.New("logicTemplate").Parse(logicTemplate)) + text, err := templatex.LoadTemplate(category, logicTemplateFile, logicTemplate) + if err != nil { + return err + } + + t := template.Must(template.New("logicTemplate").Parse(text)) buffer := new(bytes.Buffer) err = t.Execute(fp, map[string]string{ "imports": imports, diff --git a/tools/goctl/api/gogen/genmain.go b/tools/goctl/api/gogen/genmain.go index 546d82c3..76d486c0 100644 --- a/tools/goctl/api/gogen/genmain.go +++ b/tools/goctl/api/gogen/genmain.go @@ -8,6 +8,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/vars" ) @@ -60,7 +61,12 @@ func genMain(dir string, api *spec.ApiSpec) error { return err } - t := template.Must(template.New("mainTemplate").Parse(mainTemplate)) + text, err := templatex.LoadTemplate(category, mainTemplateFile, mainTemplate) + if err != nil { + return err + } + + t := template.Must(template.New("mainTemplate").Parse(text)) buffer := new(bytes.Buffer) err = t.Execute(buffer, map[string]string{ "importPackages": genMainImports(parentPkg), diff --git a/tools/goctl/api/gogen/gensvc.go b/tools/goctl/api/gogen/gensvc.go index 6d30ac15..9efd9dd3 100644 --- a/tools/goctl/api/gogen/gensvc.go +++ b/tools/goctl/api/gogen/gensvc.go @@ -7,6 +7,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -46,8 +47,14 @@ func genServiceContext(dir string, api *spec.ApiSpec) error { if err != nil { return err } + + text, err := templatex.LoadTemplate(category, contextTemplateFile, contextTemplate) + if err != nil { + return err + } + var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" - t := template.Must(template.New("contextTemplate").Parse(contextTemplate)) + t := template.Must(template.New("contextTemplate").Parse(text)) buffer := new(bytes.Buffer) err = t.Execute(buffer, map[string]string{ "configImport": configImport, diff --git a/tools/goctl/api/gogen/template.go b/tools/goctl/api/gogen/template.go new file mode 100644 index 00000000..338e933a --- /dev/null +++ b/tools/goctl/api/gogen/template.go @@ -0,0 +1,29 @@ +package gogen + +import ( + "github.com/tal-tech/go-zero/tools/goctl/templatex" + "github.com/urfave/cli" +) + +const ( + category = "api" + configTemplateFile = "config.tpl" + contextTemplateFile = "context.tpl" + etcTemplateFile = "etc.tpl" + handlerTemplateFile = "handler.tpl" + logicTemplateFile = "logic.tpl" + mainTemplateFile = "main.tpl" +) + +var templates = map[string]string{ + configTemplateFile: configTemplate, + contextTemplateFile: contextTemplate, + etcTemplateFile: etcTemplate, + handlerTemplateFile: handlerTemplate, + logicTemplateFile: logicTemplate, + mainTemplateFile: mainTemplate, +} + +func GenTemplates(_ *cli.Context) error { + return templatex.InitTemplates(category, templates) +} diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index e51b09c3..316dec53 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -102,6 +102,13 @@ var ( }, }, Action: gogen.GoCommand, + Subcommands: []cli.Command{ + { + Name: "template", + Usage: "initialize the api templates", + Action: gogen.GenTemplates, + }, + }, }, { Name: "java", diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index d6c0701e..51e39211 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -5,7 +5,7 @@ import ( "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -22,7 +22,7 @@ func genDelete(table Table, withCache bool) (string, error) { } camel := table.Name.ToCamel() - output, err := util.With("delete"). + output, err := templatex.With("delete"). Parse(template.Delete). Execute(map[string]interface{}{ "upperStartCamelObject": camel, diff --git a/tools/goctl/model/sql/gen/field.go b/tools/goctl/model/sql/gen/field.go index 6c45c2c4..b29dcae7 100644 --- a/tools/goctl/model/sql/gen/field.go +++ b/tools/goctl/model/sql/gen/field.go @@ -5,7 +5,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) func genFields(fields []parser.Field) (string, error) { @@ -25,7 +25,7 @@ func genField(field parser.Field) (string, error) { if err != nil { return "", err } - output, err := util.With("types"). + output, err := templatex.With("types"). Parse(template.Field). Execute(map[string]interface{}{ "name": field.Name.ToCamel(), diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index 8f48f873..93ab2e88 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -2,13 +2,13 @@ package gen import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) func genFindOne(table Table, withCache bool) (string, error) { camel := table.Name.ToCamel() - output, err := util.With("findOne"). + output, err := templatex.With("findOne"). Parse(template.FindOne). Execute(map[string]interface{}{ "withCache": withCache, diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index 8bd1d7ba..d094436d 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -5,12 +5,12 @@ import ( "strings" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) func genFindOneByField(table Table, withCache bool) (string, string, error) { - t := util.With("findOneByField").Parse(template.FindOneByField) + t := templatex.With("findOneByField").Parse(template.FindOneByField) var list []string camelTableName := table.Name.ToCamel() for _, field := range table.Fields { @@ -36,7 +36,7 @@ func genFindOneByField(table Table, withCache bool) (string, string, error) { list = append(list, output.String()) } if withCache { - out, err := util.With("findOneByFieldExtraMethod").Parse(template.FindOneByFieldExtraMethod).Execute(map[string]interface{}{ + out, err := templatex.With("findOneByFieldExtraMethod").Parse(template.FindOneByFieldExtraMethod).Execute(map[string]interface{}{ "upperStartCamelObject": camelTableName, "primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left, "lowerStartCamelObject": stringx.From(camelTableName).UnTitle(), diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index f351676c..6a7f0c57 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -9,6 +9,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" @@ -119,7 +120,7 @@ type ( ) func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { - t := util.With("model"). + t := templatex.With("model"). Parse(template.Model). GoFmt(true) diff --git a/tools/goctl/model/sql/gen/imports.go b/tools/goctl/model/sql/gen/imports.go index 6d29ca56..c9c62f59 100644 --- a/tools/goctl/model/sql/gen/imports.go +++ b/tools/goctl/model/sql/gen/imports.go @@ -2,12 +2,12 @@ package gen import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) func genImports(withCache, timeImport bool) (string, error) { if withCache { - buffer, err := util.With("import").Parse(template.Imports).Execute(map[string]interface{}{ + buffer, err := templatex.With("import").Parse(template.Imports).Execute(map[string]interface{}{ "time": timeImport, }) if err != nil { @@ -15,7 +15,7 @@ func genImports(withCache, timeImport bool) (string, error) { } return buffer.String(), nil } else { - buffer, err := util.With("import").Parse(template.ImportsNoCache).Execute(map[string]interface{}{ + buffer, err := templatex.With("import").Parse(template.ImportsNoCache).Execute(map[string]interface{}{ "time": timeImport, }) if err != nil { diff --git a/tools/goctl/model/sql/gen/insert.go b/tools/goctl/model/sql/gen/insert.go index b5600b9d..65341586 100644 --- a/tools/goctl/model/sql/gen/insert.go +++ b/tools/goctl/model/sql/gen/insert.go @@ -5,7 +5,7 @@ import ( "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -34,7 +34,7 @@ func genInsert(table Table, withCache bool) (string, error) { expressionValues = append(expressionValues, "data."+camel) } camel := table.Name.ToCamel() - output, err := util.With("insert"). + output, err := templatex.With("insert"). Parse(template.Insert). Execute(map[string]interface{}{ "withCache": withCache, diff --git a/tools/goctl/model/sql/gen/new.go b/tools/goctl/model/sql/gen/new.go index e5afe0ea..ab0cf7e6 100644 --- a/tools/goctl/model/sql/gen/new.go +++ b/tools/goctl/model/sql/gen/new.go @@ -2,11 +2,11 @@ package gen import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) func genNew(table Table, withCache bool) (string, error) { - output, err := util.With("new"). + output, err := templatex.With("new"). Parse(template.New). Execute(map[string]interface{}{ "withCache": withCache, diff --git a/tools/goctl/model/sql/gen/tag.go b/tools/goctl/model/sql/gen/tag.go index 86f15c19..3d46afd8 100644 --- a/tools/goctl/model/sql/gen/tag.go +++ b/tools/goctl/model/sql/gen/tag.go @@ -2,14 +2,14 @@ package gen import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) func genTag(in string) (string, error) { if in == "" { return in, nil } - output, err := util.With("tag"). + output, err := templatex.With("tag"). Parse(template.Tag). Execute(map[string]interface{}{ "field": in, diff --git a/tools/goctl/model/sql/gen/types.go b/tools/goctl/model/sql/gen/types.go index d6801e6b..24ac571d 100644 --- a/tools/goctl/model/sql/gen/types.go +++ b/tools/goctl/model/sql/gen/types.go @@ -2,7 +2,7 @@ package gen import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) func genTypes(table Table, withCache bool) (string, error) { @@ -11,7 +11,7 @@ func genTypes(table Table, withCache bool) (string, error) { if err != nil { return "", err } - output, err := util.With("types"). + output, err := templatex.With("types"). Parse(template.Types). Execute(map[string]interface{}{ "withCache": withCache, diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index 426f8fa8..edc59aac 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -22,7 +22,7 @@ func genUpdate(table Table, withCache bool) (string, error) { } expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel()) camelTableName := table.Name.ToCamel() - output, err := util.With("update"). + output, err := templatex.With("update"). Parse(template.Update). Execute(map[string]interface{}{ "withCache": withCache, diff --git a/tools/goctl/model/sql/gen/vars.go b/tools/goctl/model/sql/gen/vars.go index 64e043e9..4bb79211 100644 --- a/tools/goctl/model/sql/gen/vars.go +++ b/tools/goctl/model/sql/gen/vars.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -14,7 +14,7 @@ func genVars(table Table, withCache bool) (string, error) { keys = append(keys, v.VarExpression) } camel := table.Name.ToCamel() - output, err := util.With("var"). + output, err := templatex.With("var"). Parse(template.Vars). GoFmt(true). Execute(map[string]interface{}{ diff --git a/tools/goctl/rpc/gen/gencall.go b/tools/goctl/rpc/gen/gencall.go index feaf534f..cf960022 100644 --- a/tools/goctl/rpc/gen/gencall.go +++ b/tools/goctl/rpc/gen/gencall.go @@ -7,6 +7,7 @@ import ( "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -122,8 +123,8 @@ func (g *defaultRpcGenerator) genCall() error { } filename := filepath.Join(callPath, typesFilename) - head := util.GetHead(g.Ctx.ProtoSource) - err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{ + head := templatex.GetHead(g.Ctx.ProtoSource) + err = templatex.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{ "head": head, "const": constLit, "filePackage": service.Name.Lower(), @@ -146,7 +147,7 @@ func (g *defaultRpcGenerator) genCall() error { return err } - err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{ + err = templatex.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{ "name": service.Name.Lower(), "head": head, "filePackage": service.Name.Lower(), @@ -166,7 +167,7 @@ func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb))) for _, method := range service.Funcs { imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) - buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{ "rpcServiceName": service.Name.Title(), "method": method.Name.Title(), "package": pkgName, @@ -189,7 +190,7 @@ func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]s functions := make([]string, 0) for _, method := range service.Funcs { - buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute( + buffer, err := templatex.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute( map[string]interface{}{ "hasComment": method.HaveDoc(), "comment": method.GetDoc(), diff --git a/tools/goctl/rpc/gen/genetc.go b/tools/goctl/rpc/gen/genetc.go index 4a655e42..d96ec85d 100644 --- a/tools/goctl/rpc/gen/genetc.go +++ b/tools/goctl/rpc/gen/genetc.go @@ -4,6 +4,7 @@ import ( "fmt" "path/filepath" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -22,7 +23,7 @@ func (g *defaultRpcGenerator) genEtc() error { return nil } - return util.With("etc").Parse(etcTemplate).SaveTo(map[string]interface{}{ + return templatex.With("etc").Parse(etcTemplate).SaveTo(map[string]interface{}{ "serviceName": g.Ctx.ServiceName.Lower(), }, fileName, false) } diff --git a/tools/goctl/rpc/gen/genlogic.go b/tools/goctl/rpc/gen/genlogic.go index b5e76435..c4527a95 100644 --- a/tools/goctl/rpc/gen/genlogic.go +++ b/tools/goctl/rpc/gen/genlogic.go @@ -7,6 +7,7 @@ import ( "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -61,7 +62,7 @@ func (g *defaultRpcGenerator) genLogic() error { svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) imports.AddStr(svcImport) imports.AddStr(importList...) - err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{ + err = templatex.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{ "logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "functions": functions, "imports": strings.Join(imports.KeysStr(), util.NL), @@ -82,7 +83,7 @@ func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parse } imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) imports.AddStr(g.ast.Imports[method.ParameterOut.Package]) - buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{ "logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "method": method.Name.Title(), "request": method.ParameterIn.StarExpression, diff --git a/tools/goctl/rpc/gen/genmain.go b/tools/goctl/rpc/gen/genmain.go index be5a26a6..00680e73 100644 --- a/tools/goctl/rpc/gen/genmain.go +++ b/tools/goctl/rpc/gen/genmain.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -58,8 +59,8 @@ func (g *defaultRpcGenerator) genMain() error { configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)) imports = append(imports, configImport, pbImport, remoteImport, svcImport) srv, registers := g.genServer(pkg, file.Service) - head := util.GetHead(g.Ctx.ProtoSource) - return util.With("main").GoFmt(true).Parse(mainTemplate).SaveTo(map[string]interface{}{ + head := templatex.GetHead(g.Ctx.ProtoSource) + return templatex.With("main").GoFmt(true).Parse(mainTemplate).SaveTo(map[string]interface{}{ "head": head, "package": pkg, "serviceName": g.Ctx.ServiceName.Lower(), diff --git a/tools/goctl/rpc/gen/genserver.go b/tools/goctl/rpc/gen/genserver.go index c9d61782..c6051220 100644 --- a/tools/goctl/rpc/gen/genserver.go +++ b/tools/goctl/rpc/gen/genserver.go @@ -7,6 +7,7 @@ import ( "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -51,7 +52,7 @@ func (g *defaultRpcGenerator) genHandler() error { imports := collection.NewSet() imports.AddStr(logicImport, svcImport) - head := util.GetHead(g.Ctx.ProtoSource) + head := templatex.GetHead(g.Ctx.ProtoSource) for _, service := range file.Service { filename := fmt.Sprintf("%vserver.go", service.Name.Lower()) serverFile := filepath.Join(serverPath, filename) @@ -60,7 +61,7 @@ func (g *defaultRpcGenerator) genHandler() error { return err } imports.AddStr(importList...) - err = util.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{ + err = templatex.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{ "head": head, "types": fmt.Sprintf(typeFmt, service.Name.Title()), "server": service.Name.Title(), @@ -85,7 +86,7 @@ func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string } imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) imports.AddStr(g.ast.Imports[method.ParameterOut.Package]) - buffer, err := util.With("func").Parse(functionTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("func").Parse(functionTemplate).Execute(map[string]interface{}{ "server": service.Name.Title(), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "method": method.Name.Title(), diff --git a/tools/goctl/rpc/gen/gensvc.go b/tools/goctl/rpc/gen/gensvc.go index c2315813..1a0fa2cd 100644 --- a/tools/goctl/rpc/gen/gensvc.go +++ b/tools/goctl/rpc/gen/gensvc.go @@ -4,7 +4,7 @@ import ( "fmt" "path/filepath" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" ) const svcTemplate = `package svc @@ -25,7 +25,7 @@ func NewServiceContext(c config.Config) *ServiceContext { func (g *defaultRpcGenerator) genSvc() error { svcPath := g.dirM[dirSvc] fileName := filepath.Join(svcPath, fileServiceContext) - return util.With("svc").GoFmt(true).Parse(svcTemplate).SaveTo(map[string]interface{}{ + return templatex.With("svc").GoFmt(true).Parse(svcTemplate).SaveTo(map[string]interface{}{ "imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)), }, fileName, false) } diff --git a/tools/goctl/rpc/gen/template.go b/tools/goctl/rpc/gen/template.go index d81c44a0..12954432 100644 --- a/tools/goctl/rpc/gen/template.go +++ b/tools/goctl/rpc/gen/template.go @@ -4,7 +4,7 @@ import ( "path/filepath" "strings" - "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -43,7 +43,7 @@ func (r *rpcTemplate) MustGenerate(showState bool) { r.Info("generating template...") protoFilename := filepath.Base(r.out) serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename))) - err := util.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{ + err := templatex.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{ "package": serviceName.UnTitle(), "serviceName": serviceName.Title(), }, r.out, false) diff --git a/tools/goctl/rpc/parser/pbast.go b/tools/goctl/rpc/parser/pbast.go index 6a5872f5..a57300b5 100644 --- a/tools/goctl/rpc/parser/pbast.go +++ b/tools/goctl/rpc/parser/pbast.go @@ -12,6 +12,7 @@ import ( "github.com/tal-tech/go-zero/core/lang" sx "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" @@ -589,7 +590,7 @@ func (a *PbAst) GenTypesCode() (string, error) { types = append(types, typeCode) } - buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("type").Parse(typeTemplate).Execute(map[string]interface{}{ "types": strings.Join(types, util.NL+util.NL), }) if err != nil { @@ -614,7 +615,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) { comment = f.Comment[0] } doc = strings.Join(f.Document, util.NL) - buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{ "name": f.Name.Title(), "type": f.Type.InvokeTypeExpression, "tag": f.JsonTag, @@ -629,7 +630,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) { fields = append(fields, buffer.String()) } - buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("struct").Parse(structTemplate).Execute(map[string]interface{}{ "type": containsTypeStatement, "name": s.Name.Title(), "fields": strings.Join(fields, util.NL), diff --git a/tools/goctl/rpc/parser/proto.go b/tools/goctl/rpc/parser/proto.go index e0e6f947..8632d7e0 100644 --- a/tools/goctl/rpc/parser/proto.go +++ b/tools/goctl/rpc/parser/proto.go @@ -10,6 +10,7 @@ import ( "github.com/emicklei/proto" "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/lang" + "github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -262,7 +263,7 @@ func (e *Enum) GenEnumCode() (string, error) { } element = append(element, code) } - buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{ "element": strings.Join(element, util.NL), }) if err != nil { @@ -272,7 +273,7 @@ func (e *Enum) GenEnumCode() (string, error) { } func (e *Enum) GenEnumTypeCode() (string, error) { - buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{ "name": e.Name.Source(), }) if err != nil { @@ -282,7 +283,7 @@ func (e *Enum) GenEnumTypeCode() (string, error) { } func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) { - buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{ + buffer, err := templatex.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{ "key": e.Key, "name": parentName, "value": e.Value, diff --git a/tools/goctl/templatex/files.go b/tools/goctl/templatex/files.go new file mode 100644 index 00000000..0f8e825a --- /dev/null +++ b/tools/goctl/templatex/files.go @@ -0,0 +1,79 @@ +package templatex + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + + "github.com/logrusorgru/aurora" + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +const goctlDir = ".goctl" + +func InitTemplates(category string, templates map[string]string) error { + dir, err := getTemplateDir(category) + if err != nil { + return err + } + + if err := util.MkdirIfNotExist(dir); err != nil { + return err + } + + for k, v := range templates { + if err := createTemplate(filepath.Join(dir, k), v); err != nil { + return err + } + } + + fmt.Printf("Templates are generated in %s, %s\n", aurora.Green(dir), + aurora.Red("edit on your risk!")) + + return nil +} + +func LoadTemplate(category, file, builtin string) (string, error) { + dir, err := getTemplateDir(category) + if err != nil { + return "", err + } + + file = filepath.Join(dir, file) + if !util.FileExists(file) { + return builtin, nil + } + + content, err := ioutil.ReadFile(file) + if err != nil { + return "", err + } + + return string(content), nil +} + +func createTemplate(file, content string) error { + if util.FileExists(file) { + println(1) + return nil + } + + f, err := os.Create(file) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(content) + return err +} + +func getTemplateDir(category string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, goctlDir, category), nil +} diff --git a/tools/goctl/util/head.go b/tools/goctl/templatex/head.go similarity index 93% rename from tools/goctl/util/head.go rename to tools/goctl/templatex/head.go index b7534c71..ba44dcc0 100644 --- a/tools/goctl/util/head.go +++ b/tools/goctl/templatex/head.go @@ -1,4 +1,4 @@ -package util +package templatex var headTemplate = `// Code generated by goctl. DO NOT EDIT! // Source: {{.source}}` diff --git a/tools/goctl/util/templatex.go b/tools/goctl/templatex/templatex.go similarity index 68% rename from tools/goctl/util/templatex.go rename to tools/goctl/templatex/templatex.go index 0e623f80..3abbaecb 100644 --- a/tools/goctl/util/templatex.go +++ b/tools/goctl/templatex/templatex.go @@ -1,22 +1,23 @@ -package util +package templatex import ( "bytes" goformat "go/format" "io/ioutil" - "os" "text/template" -) -type ( - defaultTemplate struct { - name string - text string - goFmt bool - savePath string - } + "github.com/tal-tech/go-zero/tools/goctl/util" ) +const regularPerm = 0666 + +type defaultTemplate struct { + name string + text string + goFmt bool + savePath string +} + func With(name string) *defaultTemplate { return &defaultTemplate{ name: name, @@ -33,37 +34,38 @@ func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate { } func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error { - if FileExists(path) && !forceUpdate { + if util.FileExists(path) && !forceUpdate { return nil } - output, err := t.execute(data) + + output, err := t.Execute(data) if err != nil { return err } - return ioutil.WriteFile(path, output.Bytes(), os.ModePerm) -} -func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { - return t.execute(data) + return ioutil.WriteFile(path, output.Bytes(), regularPerm) } -func (t *defaultTemplate) execute(data interface{}) (*bytes.Buffer, error) { +func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { tem, err := template.New(t.name).Parse(t.text) if err != nil { return nil, err } + buf := new(bytes.Buffer) - err = tem.Execute(buf, data) - if err != nil { + if err = tem.Execute(buf, data); err != nil { return nil, err } + if !t.goFmt { return buf, nil } + formatOutput, err := goformat.Source(buf.Bytes()) if err != nil { return nil, err } + buf.Reset() buf.Write(formatOutput) return buf, nil