Code optimized (#493)

master
kingxt 4 years ago committed by GitHub
parent 059027bc9d
commit f98c9246b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -39,6 +39,7 @@ service {{.serviceName}} {
}
`
// ApiCommand create api template file
func ApiCommand(c *cli.Context) error {
apiFile := c.String("o")
if len(apiFile) == 0 {

@ -9,6 +9,7 @@ import (
"github.com/urfave/cli"
)
// DartCommand create dart network request code
func DartCommand(c *cli.Context) error {
apiFile := c.String("api")
dir := c.String("dir")

@ -2,9 +2,9 @@ package dartgen
import (
"os"
"reflect"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
)
@ -32,10 +32,18 @@ func pathToFuncName(path string) string {
return util.ToLower(camel[:1]) + camel[1:]
}
func tagGet(tag, k string) (reflect.Value, error) {
v, _ := util.TagLookup(tag, k)
out := strings.Split(v, ",")[0]
return reflect.ValueOf(out), nil
func tagGet(tag, k string) string {
tags, err := spec.Parse(tag)
if err != nil {
panic(k + " not exist")
}
v, err := tags.Get(k)
if err != nil {
panic(k + " value not exist")
}
return v.Name
}
func isDirectType(s string) bool {

@ -12,6 +12,7 @@ import (
"github.com/urfave/cli"
)
// DocCommand generate markdown doc file
func DocCommand(c *cli.Context) error {
dir := c.String("dir")
if len(dir) == 0 {

@ -160,6 +160,7 @@ type Response struct {
@server(
jwt: Auth
signature: true
)
service A-api {
@handler GreetHandler

@ -40,7 +40,7 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
for _, item := range authNames {
auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate))
}
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return genFile(fileGenConfig{
dir: dir,

@ -109,7 +109,7 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str
if len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir)))
}
imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceUrl))
imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t")
}

@ -122,6 +122,6 @@ func genLogicImports(route spec.Route, parentPkg string) string {
if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
}
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl))
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t")
}

@ -74,7 +74,7 @@ func genMainImports(parentPkg string) string {
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir)))
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir)))
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir)))
imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl))
imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl))
imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceURL))
imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t")
}

@ -89,7 +89,7 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
}
var signature string
if g.signatureEnabled {
signature = fmt.Sprintf("\n rest.WithSignature(serverCtx.Config.%s.Signature),", g.authName)
signature = "\n rest.WithSignature(serverCtx.Config.Signature),"
}
var routes string
@ -163,7 +163,7 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
imports := importSet.KeysStr()
sort.Strings(imports)
projectSection := strings.Join(imports, "\n\t")
depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)
depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection)
}
@ -196,6 +196,10 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
groupedRoutes.authName = jwt
groupedRoutes.jwtEnabled = true
}
signature := g.GetAnnotation("signature")
if signature == "true" {
groupedRoutes.signatureEnabled = true
}
middleware := g.GetAnnotation("middleware")
if len(middleware) > 0 {
for _, item := range strings.Split(middleware, ",") {

@ -64,7 +64,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
if len(middlewareStr) > 0 {
configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl)
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL)
}
return genFile(fileGenConfig{

@ -94,10 +94,6 @@ func getAuths(api *spec.ApiSpec) []string {
if len(jwt) > 0 {
authNames.Add(jwt)
}
signature := g.GetAnnotation("signature")
if len(signature) > 0 {
authNames.Add(signature)
}
}
return authNames.KeysStr()
}

@ -22,12 +22,6 @@ type Api struct {
}
func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
defer func() {
if p := recover(); p != nil {
panic(fmt.Errorf("%+v", p))
}
}()
var final Api
final.importM = map[string]PlaceHolder{}
final.typeM = map[string]PlaceHolder{}
@ -36,107 +30,126 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
final.routeM = map[string]PlaceHolder{}
for _, each := range ctx.AllSpec() {
root := each.Accept(v).(*Api)
if root.Syntax != nil {
if final.Syntax != nil {
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration"))
}
final.Syntax = root.Syntax
}
v.acceptSyntax(root, &final)
v.accetpImport(root, &final)
v.acceptInfo(root, &final)
v.acceptType(root, &final)
v.acceptService(root, &final)
}
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
return &final
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
func (v *ApiVisitor) acceptService(root *Api, final *Api) {
for _, service := range root.Service {
if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 {
v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration"))
}
v.duplicateServerItemCheck(service)
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
for _, route := range service.ServiceApi.ServiceRoute {
uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text())
if _, ok := final.routeM[uniqueRoute]; ok {
v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute))
}
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
final.routeM[uniqueRoute] = Holder
var handlerExpr Expr
if route.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range route.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
if kv.Key.Text() == "handler" {
handlerExpr = kv.Value
}
}
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
if route.AtHandler != nil {
handlerExpr = route.AtHandler.Name
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
}
for _, service := range root.Service {
if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 {
v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration"))
if handlerExpr == nil {
v.panic(route.Route.Method, fmt.Sprintf("mismtached handler"))
}
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
if handlerExpr.Text() == "" {
v.panic(handlerExpr, fmt.Sprintf("mismtached handler"))
}
atServerM[kv.Key.Text()] = Holder
}
if _, ok := final.handlerM[handlerExpr.Text()]; ok {
v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text()))
}
final.handlerM[handlerExpr.Text()] = Holder
}
final.Service = append(final.Service, service)
}
}
for _, route := range service.ServiceApi.ServiceRoute {
uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text())
if _, ok := final.routeM[uniqueRoute]; ok {
v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute))
}
func (v *ApiVisitor) duplicateServerItemCheck(service *Service) {
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
final.routeM[uniqueRoute] = Holder
var handlerExpr Expr
if route.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range route.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
if kv.Key.Text() == "handler" {
handlerExpr = kv.Value
}
}
}
atServerM[kv.Key.Text()] = Holder
}
}
}
if route.AtHandler != nil {
handlerExpr = route.AtHandler.Name
}
func (v *ApiVisitor) acceptType(root *Api, final *Api) {
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
}
if handlerExpr == nil {
v.panic(route.Route.Method, fmt.Sprintf("mismtached handler"))
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
}
}
if handlerExpr.Text() == "" {
v.panic(handlerExpr, fmt.Sprintf("mismtached handler"))
}
func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
}
if _, ok := final.handlerM[handlerExpr.Text()]; ok {
v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text()))
}
final.handlerM[handlerExpr.Text()] = Holder
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
}
final.Service = append(final.Service, service)
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
}
return &final
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
}
}
func (v *ApiVisitor) acceptSyntax(root *Api, final *Api) {
if root.Syntax != nil {
if final.Syntax != nil {
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration"))
}
final.Syntax = root.Syntax
}
}
func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} {

@ -156,28 +156,9 @@ func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
}
func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
if len(nestedApi.Import) > 0 {
importToken := nestedApi.Import[0].Import
return fmt.Errorf("%s line %d:%d the nested api does not support import",
nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
}
err := p.nestedApiCheck(mainApi, nestedApi)
if err != nil {
return err
}
mainHandlerMap := make(map[string]PlaceHolder)
@ -218,6 +199,23 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
}
// duplicate route check
err = p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap)
if err != nil {
return err
}
// duplicate type check
for _, each := range nestedApi.Type {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
}
}
return nil
}
func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap map[string]PlaceHolder, mainRouteMap map[string]PlaceHolder) error {
for _, each := range nestedApi.Service {
for _, r := range each.ServiceApi.ServiceRoute {
handler := r.GetHandler()
@ -237,12 +235,31 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
}
}
}
return nil
}
// duplicate type check
for _, each := range nestedApi.Type {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
func (p *Parser) nestedApiCheck(mainApi *Api, nestedApi *Api) error {
if len(nestedApi.Import) > 0 {
importToken := nestedApi.Import[0].Import
return fmt.Errorf("%s line %d:%d the nested api does not support import",
nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
}
}
return nil
@ -276,55 +293,51 @@ func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
for _, apiItem := range apiList {
linePrefix := apiItem.LinePrefix
for _, each := range apiItem.Type {
tp, ok := each.(*TypeStruct)
if !ok {
continue
}
err := p.checkTypes(apiItem, linePrefix, types)
if err != nil {
return err
}
for _, member := range tp.Fields {
err := p.checkType(linePrefix, types, member.DataType)
if err != nil {
return err
}
}
err = p.checkServices(apiItem, types, linePrefix)
if err != nil {
return err
}
}
return nil
}
for _, service := range apiItem.Service {
for _, each := range service.ServiceApi.ServiceRoute {
route := each.Route
if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() {
_, ok := types[route.Req.Name.Expr().Text()]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text())
}
}
func (p *Parser) checkServices(apiItem *Api, types map[string]TypeExpr, linePrefix string) error {
for _, service := range apiItem.Service {
for _, each := range service.ServiceApi.ServiceRoute {
route := each.Route
err := p.checkRequestBody(route, types, linePrefix)
if err != nil {
return err
}
if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() {
reply := route.Reply.Name
var structName string
switch tp := reply.(type) {
if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() {
reply := route.Reply.Name
var structName string
switch tp := reply.(type) {
case *Literal:
structName = tp.Literal.Text()
case *Array:
switch innerTp := tp.Literal.(type) {
case *Literal:
structName = tp.Literal.Text()
case *Array:
switch innerTp := tp.Literal.(type) {
case *Literal:
structName = innerTp.Literal.Text()
case *Pointer:
structName = innerTp.Name.Text()
}
structName = innerTp.Literal.Text()
case *Pointer:
structName = innerTp.Name.Text()
}
}
if api.IsBasicType(structName) {
continue
}
if api.IsBasicType(structName) {
continue
}
_, ok := types[structName]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName)
}
_, ok := types[structName]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName)
}
}
}
@ -332,6 +345,34 @@ func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
return nil
}
func (p *Parser) checkRequestBody(route *Route, types map[string]TypeExpr, linePrefix string) error {
if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() {
_, ok := types[route.Req.Name.Expr().Text()]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text())
}
}
return nil
}
func (p *Parser) checkTypes(apiItem *Api, linePrefix string, types map[string]TypeExpr) error {
for _, each := range apiItem.Type {
tp, ok := each.(*TypeStruct)
if !ok {
continue
}
for _, member := range tp.Fields {
err := p.checkType(linePrefix, types, member.DataType)
if err != nil {
return err
}
}
}
return nil
}
func (p *Parser) checkType(linePrefix string, types map[string]TypeExpr, expr DataType) error {
if expr == nil {
return nil

@ -213,13 +213,7 @@ func (p parser) fillService() error {
var groups []spec.Group
for _, item := range p.ast.Service {
var group spec.Group
if item.AtServer != nil {
var properties = make(map[string]string, 0)
for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
group.Annotation.Properties = properties
}
p.fillAtServer(item, &group)
for _, astRoute := range item.ServiceApi.ServiceRoute {
route := spec.Route{
@ -231,25 +225,9 @@ func (p parser) fillService() error {
route.Handler = astRoute.AtHandler.Name.Text()
}
if astRoute.AtServer != nil {
var properties = make(map[string]string, 0)
for _, kv := range astRoute.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
route.Annotation.Properties = properties
if len(route.Handler) == 0 {
route.Handler = properties["handler"]
}
if len(route.Handler) == 0 {
return fmt.Errorf("missing handler annotation for %q", route.Path)
}
for _, char := range route.Handler {
if !unicode.IsDigit(char) && !unicode.IsLetter(char) {
return fmt.Errorf("route [%s] handler [%s] invalid, handler name should only contains letter or digit",
route.Path, route.Handler)
}
}
err := p.fillRouteAtServer(astRoute, &route)
if err != nil {
return err
}
if astRoute.Route.Req != nil {
@ -269,7 +247,7 @@ func (p parser) fillService() error {
}
}
err := p.fillRouteType(&route)
err = p.fillRouteType(&route)
if err != nil {
return err
}
@ -289,6 +267,40 @@ func (p parser) fillService() error {
return nil
}
func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route) error {
if astRoute.AtServer != nil {
var properties = make(map[string]string, 0)
for _, kv := range astRoute.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
route.Annotation.Properties = properties
if len(route.Handler) == 0 {
route.Handler = properties["handler"]
}
if len(route.Handler) == 0 {
return fmt.Errorf("missing handler annotation for %q", route.Path)
}
for _, char := range route.Handler {
if !unicode.IsDigit(char) && !unicode.IsLetter(char) {
return fmt.Errorf("route [%s] handler [%s] invalid, handler name should only contains letter or digit",
route.Path, route.Handler)
}
}
}
return nil
}
func (p parser) fillAtServer(item *ast.Service, group *spec.Group) {
if item.AtServer != nil {
var properties = make(map[string]string, 0)
for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
group.Annotation.Properties = properties
}
}
func (p parser) fillRouteType(route *spec.Route) error {
if route.RequestType != nil {
switch route.RequestType.(type) {

@ -1,58 +0,0 @@
package util
import (
"strconv"
"strings"
)
func TagLookup(tag, key string) (value string, ok bool) {
tag = strings.Replace(tag, "`", "", -1)
for tag != "" {
// Skip leading space.
i := 0
for i < len(tag) && tag[i] == ' ' {
i++
}
tag = tag[i:]
if tag == "" {
break
}
// Scan to colon. A space, a quote or a control character is a syntax error.
// Strictly speaking, control chars include the range [0x7f, 0x9f], not just
// [0x00, 0x1f], but in practice, we ignore the multi-byte control characters
// as it is simpler to inspect the tag's bytes than the tag's runes.
i = 0
for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f {
i++
}
if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' {
break
}
name := string(tag[:i])
tag = tag[i+1:]
// Scan quoted string to find value.
i = 1
for i < len(tag) && tag[i] != '"' {
if tag[i] == '\\' {
i++
}
i++
}
if i >= len(tag) {
break
}
qvalue := string(tag[:i+1])
tag = tag[i+1:]
if key == name {
value, err := strconv.Unquote(qvalue)
if err != nil {
break
}
return value, true
}
}
return "", false
}

@ -28,7 +28,7 @@ import (
)
var (
BuildVersion = "1.1.5"
buildVersion = "1.1.5"
commands = []cli.Command{
{
Name: "upgrade",
@ -510,7 +510,7 @@ func main() {
app := cli.NewApp()
app.Usage = "a cli tool to generate code"
app.Version = fmt.Sprintf("%s %s/%s", BuildVersion, runtime.GOOS, runtime.GOARCH)
app.Version = fmt.Sprintf("%s %s/%s", buildVersion, runtime.GOOS, runtime.GOARCH)
app.Commands = commands
// cli already print error messages
if err := app.Run(os.Args); err != nil {

@ -1,6 +1,7 @@
package gen
import (
"bytes"
"fmt"
"io/ioutil"
"os"
@ -31,7 +32,20 @@ type (
pkg string
cfg *config.Config
}
Option func(generator *defaultGenerator)
code struct {
importsCode string
varsCode string
typesCode string
newCode string
insertCode string
findCode []string
updateCode string
deleteCode string
cacheExtra string
}
)
func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) {
@ -186,15 +200,6 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", fmt.Errorf("table %s: missing primary key", in.Name.Source())
}
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil {
return "", err
}
t := util.With("model").
Parse(text).
GoFmt(true)
m, err := genCacheKeys(in)
if err != nil {
return "", err
@ -261,18 +266,19 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", err
}
output, err := t.Execute(map[string]interface{}{
"pkg": g.pkg,
"imports": importsCode,
"vars": varsCode,
"types": typesCode,
"new": newCode,
"insert": insertCode,
"find": strings.Join(findCode, "\n"),
"update": updateCode,
"delete": deleteCode,
"extraMethod": ret.cacheExtra,
})
code := &code{
importsCode: importsCode,
varsCode: varsCode,
typesCode: typesCode,
newCode: newCode,
insertCode: insertCode,
findCode: findCode,
updateCode: updateCode,
deleteCode: deleteCode,
cacheExtra: ret.cacheExtra,
}
output, err := g.executeModel(code)
if err != nil {
return "", err
}
@ -280,6 +286,32 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return output.String(), nil
}
func (g *defaultGenerator) executeModel(code *code) (*bytes.Buffer, error) {
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil {
return nil, err
}
t := util.With("model").
Parse(text).
GoFmt(true)
output, err := t.Execute(map[string]interface{}{
"pkg": g.pkg,
"imports": code.importsCode,
"vars": code.varsCode,
"types": code.typesCode,
"new": code.newCode,
"insert": code.insertCode,
"find": strings.Join(code.findCode, "\n"),
"update": code.updateCode,
"delete": code.deleteCode,
"extraMethod": code.cacheExtra,
})
if err != nil {
return nil, err
}
return output, nil
}
func wrapWithRawString(v string) string {
if v == "`" {
return v

@ -68,40 +68,24 @@ func Parse(ddl string) (*Table, error) {
columns := tableSpec.Columns
indexes := tableSpec.Indexes
keyMap := make(map[string]KeyType)
for _, index := range indexes {
info := index.Info
if info == nil {
continue
}
if info.Primary {
if len(index.Columns) > 1 {
return nil, errPrimaryKey
}
keyMap, err := getIndexKeyType(indexes)
if err != nil {
return nil, err
}
keyMap[index.Columns[0].Column.String()] = primary
continue
}
// can optimize
if len(index.Columns) > 1 {
continue
}
column := index.Columns[0]
columnName := column.Column.String()
camelColumnName := stringx.From(columnName).ToCamel()
// by default, createTime|updateTime findOne is not used.
if camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" {
continue
}
if info.Unique {
keyMap[columnName] = unique
} else if info.Spatial {
keyMap[columnName] = spatial
} else {
keyMap[columnName] = normal
}
fields, primaryKey, err := convertFileds(columns, keyMap)
if err != nil {
return nil, err
}
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
Fields: fields,
}, nil
}
func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyType) ([]Field, Primary, error) {
var fields []Field
var primaryKey Primary
for _, column := range columns {
@ -124,7 +108,7 @@ func Parse(ddl string) (*Table, error) {
}
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
if err != nil {
return nil, err
return nil, primaryKey, err
}
var field Field
@ -145,12 +129,44 @@ func Parse(ddl string) (*Table, error) {
}
fields = append(fields, field)
}
return fields, primaryKey, nil
}
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
Fields: fields,
}, nil
func getIndexKeyType(indexes []*sqlparser.IndexDefinition) (map[string]KeyType, error) {
keyMap := make(map[string]KeyType)
for _, index := range indexes {
info := index.Info
if info == nil {
continue
}
if info.Primary {
if len(index.Columns) > 1 {
return nil, errPrimaryKey
}
keyMap[index.Columns[0].Column.String()] = primary
continue
}
// can optimize
if len(index.Columns) > 1 {
continue
}
column := index.Columns[0]
columnName := column.Column.String()
camelColumnName := stringx.From(columnName).ToCamel()
// by default, createTime|updateTime findOne is not used.
if camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" {
continue
}
if info.Unique {
keyMap[columnName] = unique
} else if info.Spatial {
keyMap[columnName] = spatial
} else {
keyMap[columnName] = normal
}
}
return keyMap, nil
}
func (t *Table) ContainsTime() bool {

@ -8,6 +8,9 @@ import (
)
type (
// Console wraps from the fmt.Sprintf,
// by default, it implemented the colorConsole to provide the colorful output to the consle
// and the ideaConsole to output with prefix for the plugin of intellij
Console interface {
Success(format string, a ...interface{})
Info(format string, a ...interface{})
@ -25,6 +28,7 @@ type (
}
)
// NewConsole returns a instance of Console
func NewConsole(idea bool) Console {
if idea {
return NewIdeaConsole()
@ -32,7 +36,8 @@ func NewConsole(idea bool) Console {
return NewColorConsole()
}
func NewColorConsole() *colorConsole {
// NewColorConsole returns a instance of colorConsole
func NewColorConsole() Console {
return &colorConsole{}
}
@ -76,7 +81,8 @@ func (c *colorConsole) Must(err error) {
}
}
func NewIdeaConsole() *ideaConsole {
// NewIdeaConsole returns a instace of ideaConsole
func NewIdeaConsole() Console {
return &ideaConsole{}
}

@ -9,6 +9,8 @@ import (
var errModuleCheck = errors.New("the work directory must be found in the go mod or the $GOPATH")
// ProjectContext is a structure for the project,
// which contains WorkDir, Name, Path and Dir
type ProjectContext struct {
WorkDir string
// Name is the root name of the project

@ -9,6 +9,8 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
// Module contains the relative data of go module,
// which is the result of the command go list
type Module struct {
Path string
Main bool

@ -10,9 +10,7 @@ import (
"github.com/logrusorgru/aurora"
)
const (
NL = "\n"
)
const NL = "\n"
func CreateIfNotExist(file string) (*os.File, error) {
_, err := os.Stat(file)

@ -18,6 +18,7 @@ const (
upper
)
// ErrNamingFormat defines an error for unknown fomat
var ErrNamingFormat = errors.New("unsupported format")
type (

@ -1,3 +1,5 @@
// Package name provides methods to verify naming style and format naming style
// See the method IsNamingValid, FormatFilename
package name
import (
@ -6,11 +8,15 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
// NamingStyle the type of string
type NamingStyle = string
const (
// NamingLower defines the lower spell case
NamingLower NamingStyle = "lower"
// NamingCamel defines the camel spell case
NamingCamel NamingStyle = "camel"
// NamingSnake defines the snake spell case
NamingSnake NamingStyle = "snake"
)
@ -29,6 +35,8 @@ func IsNamingValid(namingStyle string) (NamingStyle, bool) {
}
}
// FormatFilename converts the filename string to the target
// naming style by calling method of stringx
func FormatFilename(filename string, style NamingStyle) string {
switch style {
case NamingCamel:

@ -6,14 +6,17 @@ import (
"unicode"
)
// String provides for coverting the source text into other spell case,like lower,snake,camel
type String struct {
source string
}
// From converts the input text to String and returns it
func From(data string) String {
return String{source: data}
}
// IsEmptyOrSpace returns true if the length of the string value is 0 after call strings.TrimSpace, or else returns false
func (s String) IsEmptyOrSpace() bool {
if len(s.source) == 0 {
return true
@ -24,18 +27,22 @@ func (s String) IsEmptyOrSpace() bool {
return false
}
// Lower calls the strings.ToLower
func (s String) Lower() string {
return strings.ToLower(s.source)
}
// ReplaceAll calls the strings.ReplaceAll
func (s String) ReplaceAll(old, new string) string {
return strings.ReplaceAll(s.source, old, new)
}
//Source returns the source string value
func (s String) Source() string {
return s.source
}
// Title calls the strings.Title
func (s String) Title() string {
if s.IsEmptyOrSpace() {
return s.source
@ -43,7 +50,7 @@ func (s String) Title() string {
return strings.Title(s.source)
}
// snake->camel(upper start)
// ToCamel converts the input text into camel case
func (s String) ToCamel() string {
list := s.splitBy(func(r rune) bool {
return r == '_'
@ -55,7 +62,7 @@ func (s String) ToCamel() string {
return strings.Join(target, "")
}
// camel->snake
// ToSnake converts the input text into snake case
func (s String) ToSnake() string {
list := s.splitBy(unicode.IsUpper, false)
var target []string
@ -65,7 +72,7 @@ func (s String) ToSnake() string {
return strings.Join(target, "_")
}
// return original string if rune is not letter at index 0
// Untitle return the original string if rune is not letter at index 0
func (s String) Untitle() string {
if s.IsEmptyOrSpace() {
return s.source
@ -77,10 +84,6 @@ func (s String) Untitle() string {
return string(unicode.ToLower(r)) + s.source[1:]
}
func (s String) Upper() string {
return strings.ToUpper(s.source)
}
// it will not ignore spaces
func (s String) splitBy(fn func(r rune) bool, remove bool) []string {
if s.IsEmptyOrSpace() {

@ -9,29 +9,35 @@ import (
const regularPerm = 0666
type defaultTemplate struct {
// DefaultTemplate is a tool to provides the text/template operations
type DefaultTemplate struct {
name string
text string
goFmt bool
savePath string
}
func With(name string) *defaultTemplate {
return &defaultTemplate{
// With returns a instace of DefaultTemplate
func With(name string) *DefaultTemplate {
return &DefaultTemplate{
name: name,
}
}
func (t *defaultTemplate) Parse(text string) *defaultTemplate {
// Parse accepts a source template and returns DefaultTemplate
func (t *DefaultTemplate) Parse(text string) *DefaultTemplate {
t.text = text
return t
}
func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate {
// GoFmt sets the value to goFmt and marks the generated codes will be formated or not
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
t.goFmt = format
return t
}
func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error {
// SaveTo writes the codes to the target path
func (t *DefaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error {
if FileExists(path) && !forceUpdate {
return nil
}
@ -44,7 +50,8 @@ func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool
return ioutil.WriteFile(path, output.Bytes(), regularPerm)
}
func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) {
// Execute returns the codes after the template executed
func (t *DefaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) {
tem, err := template.New(t.name).Parse(t.text)
if err != nil {
return nil, err

@ -1,9 +1,14 @@
package vars
const (
ProjectName = "zero"
ProjectOpenSourceUrl = "github.com/tal-tech/go-zero"
OsWindows = "windows"
OsMac = "darwin"
OsLinux = "linux"
// ProjectName the const value of zero
ProjectName = "zero"
// ProjectOpenSourceURL the githb url of go-zero
ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
// OsWindows windows os
OsWindows = "windows"
// OsMac mac os
OsMac = "darwin"
// OsLinux linux os
OsLinux = "linux"
)

Loading…
Cancel
Save