rpc generation fix (#184)

* reactor alert

* optimize

* add test case

* update the target directory in case proto contains option

* fix missing comments and format code
master v1.0.25
Keson 4 years ago committed by GitHub
parent f7d778e0ed
commit 856b5aadb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,10 +18,10 @@ func Rpc(c *cli.Context) error {
out := c.String("dir") out := c.String("dir")
protoImportPath := c.StringSlice("proto_path") protoImportPath := c.StringSlice("proto_path")
if len(src) == 0 { if len(src) == 0 {
return errors.New("the proto source can not be nil") return errors.New("missing -src")
} }
if len(out) == 0 { if len(out) == 0 {
return errors.New("the target directory can not be nil") return errors.New("missing -dir")
} }
g := generator.NewDefaultRpcGenerator() g := generator.NewDefaultRpcGenerator()
return g.Generate(src, out, protoImportPath) return g.Generate(src, out, protoImportPath)

@ -11,6 +11,7 @@ import (
) )
func TestRpcGenerateCaseNilImport(t *testing.T) { func TestRpcGenerateCaseNilImport(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil { if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher) g := NewRpcGenerator(dispatcher)
@ -29,6 +30,7 @@ func TestRpcGenerateCaseNilImport(t *testing.T) {
} }
func TestRpcGenerateCaseOption(t *testing.T) { func TestRpcGenerateCaseOption(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil { if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher) g := NewRpcGenerator(dispatcher)
@ -47,6 +49,7 @@ func TestRpcGenerateCaseOption(t *testing.T) {
} }
func TestRpcGenerateCaseWordOption(t *testing.T) { func TestRpcGenerateCaseWordOption(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil { if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher) g := NewRpcGenerator(dispatcher)
@ -66,6 +69,7 @@ func TestRpcGenerateCaseWordOption(t *testing.T) {
// test keyword go // test keyword go
func TestRpcGenerateCaseGoOption(t *testing.T) { func TestRpcGenerateCaseGoOption(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil { if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher) g := NewRpcGenerator(dispatcher)
@ -84,6 +88,7 @@ func TestRpcGenerateCaseGoOption(t *testing.T) {
} }
func TestRpcGenerateCaseImport(t *testing.T) { func TestRpcGenerateCaseImport(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil { if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher) g := NewRpcGenerator(dispatcher)
@ -102,3 +107,22 @@ func TestRpcGenerateCaseImport(t *testing.T) {
}()) }())
} }
} }
func TestRpcGenerateCaseServiceRpcNamingSnake(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_service_rpc_naming_snake.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}

@ -52,7 +52,7 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
callFunctionTemplate = ` callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}} {{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) { func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn()) client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
return client.{{.method}}(ctx, in) return client.{{.method}}(ctx, in)
} }
@ -90,9 +90,9 @@ func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
"name": formatFilename(service.Name), "name": formatFilename(service.Name),
"alias": strings.Join(alias.KeysStr(), util.NL), "alias": strings.Join(alias.KeysStr(), util.NL),
"head": head, "head": head,
"filePackage": formatFilename(service.Name), "filePackage": dir.Base,
"package": fmt.Sprintf(`"%s"`, ctx.GetPb().Package), "package": fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
"serviceName": parser.CamelCase(service.Name), "serviceName": stringx.From(service.Name).ToCamel(),
"functions": strings.Join(functions, util.NL), "functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, util.NL), "interface": strings.Join(iFunctions, util.NL),
}, filename, true) }, filename, true)
@ -109,8 +109,9 @@ func (g *defaultGenerator) genFunction(goPackage string, service parser.Service)
comment := parser.GetComment(rpc.Doc()) comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{ buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"rpcServiceName": stringx.From(service.Name).Title(), "serviceName": stringx.From(service.Name).ToCamel(),
"method": stringx.From(rpc.Name).Title(), "rpcServiceName": parser.CamelCase(service.Name),
"method": parser.CamelCase(rpc.Name),
"package": goPackage, "package": goPackage,
"pbRequest": parser.CamelCase(rpc.RequestType), "pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType), "pbResponse": parser.CamelCase(rpc.ReturnsType),
@ -140,7 +141,7 @@ func (g *defaultGenerator) getInterfaceFuncs(service parser.Service) ([]string,
map[string]interface{}{ map[string]interface{}{
"hasComment": len(comment) > 0, "hasComment": len(comment) > 0,
"comment": comment, "comment": comment,
"method": stringx.From(rpc.Name).Title(), "method": parser.CamelCase(rpc.Name),
"pbRequest": parser.CamelCase(rpc.RequestType), "pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType), "pbResponse": parser.CamelCase(rpc.ReturnsType),
}) })

@ -63,7 +63,7 @@ func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error {
return err return err
} }
err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()), "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
"functions": functions, "functions": functions,
"imports": strings.Join(imports.KeysStr(), util.NL), "imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false) }, filename, false)

@ -7,6 +7,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
const mainTemplate = `{{.head}} const mainTemplate = `{{.head}}
@ -32,7 +33,7 @@ func main() {
var c config.Config var c config.Config
conf.MustLoad(*configFile, &c) conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c) ctx := svc.NewServiceContext(c)
srv := server.New{{.service}}Server(ctx) srv := server.New{{.serviceNew}}Server(ctx)
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) { s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.pkg}}.Register{{.service}}Server(grpcServer, srv) {{.pkg}}.Register{{.service}}Server(grpcServer, srv)
@ -65,6 +66,7 @@ func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
"serviceName": serviceNameLower, "serviceName": serviceNameLower,
"imports": strings.Join(imports, util.NL), "imports": strings.Join(imports, util.NL),
"pkg": proto.PbPackage, "pkg": proto.PbPackage,
"serviceNew": stringx.From(proto.Service.Name).ToCamel(),
"service": parser.CamelCase(proto.Service.Name), "service": parser.CamelCase(proto.Service.Name),
}, fileName, false) }, fileName, false)
} }

@ -20,7 +20,7 @@ func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto
cw.WriteString(" -I=" + base) cw.WriteString(" -I=" + base)
cw.WriteString(" " + proto.Name) cw.WriteString(" " + proto.Name)
if strings.Contains(proto.GoPackage, "/") { if strings.Contains(proto.GoPackage, "/") {
cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetInternal().Filename) cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetMain().Filename)
} else { } else {
cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename) cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
} }

@ -67,7 +67,7 @@ func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head, "head": head,
"server": stringx.From(service.Name).Title(), "server": stringx.From(service.Name).ToCamel(),
"imports": strings.Join(imports.KeysStr(), util.NL), "imports": strings.Join(imports.KeysStr(), util.NL),
"funcs": strings.Join(funcList, util.NL), "funcs": strings.Join(funcList, util.NL),
}, serverFile, true) }, serverFile, true)
@ -84,8 +84,8 @@ func (g *defaultGenerator) genFunctions(goPackage string, service parser.Service
comment := parser.GetComment(rpc.Doc()) comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{ buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
"server": stringx.From(service.Name).Title(), "server": stringx.From(service.Name).ToCamel(),
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()), "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
"method": parser.CamelCase(rpc.Name), "method": parser.CamelCase(rpc.Name),
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)), "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),

@ -53,8 +53,12 @@ func mkdir(ctx *ctx.ProjectContext, proto parser.Proto) (DirContext, error) {
logicDir := filepath.Join(internalDir, "logic") logicDir := filepath.Join(internalDir, "logic")
serverDir := filepath.Join(internalDir, "server") serverDir := filepath.Join(internalDir, "server")
svcDir := filepath.Join(internalDir, "svc") svcDir := filepath.Join(internalDir, "svc")
pbDir := filepath.Join(internalDir, proto.GoPackage) pbDir := filepath.Join(ctx.WorkDir, proto.GoPackage)
callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel())) callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel()))
if strings.ToLower(proto.Service.Name) == strings.ToLower(proto.GoPackage) {
callDir = filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name+"_client").ToCamel()))
}
inner[wd] = Dir{ inner[wd] = Dir{
Filename: ctx.WorkDir, Filename: ctx.WorkDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))), Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))),

@ -6,11 +6,11 @@ option go_package = "go";
import "test_base.proto"; import "test_base.proto";
message TestMessage{ message TestMessage {
base.CommonReq req = 1; base.CommonReq req = 1;
} }
message TestReq{} message TestReq {}
message TestReply{ message TestReply {
base.CommonReply reply = 2; base.CommonReply reply = 2;
} }
@ -20,6 +20,6 @@ enum TestEnum {
female = 2; female = 2;
} }
service TestService{ service TestService {
rpc TestRpc (TestReq)returns(TestReply); rpc TestRpc (TestReq) returns (TestReply);
} }

@ -14,5 +14,5 @@ message StreamResp {
} }
service StreamGreeter { service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp); rpc greet (StreamReq) returns (StreamResp);
} }

@ -14,5 +14,5 @@ message Out {
} }
service StreamGreeter { service StreamGreeter {
rpc greet(In) returns (Out); rpc greet (In) returns (Out);
} }

@ -14,5 +14,5 @@ message StreamResp {
} }
service StreamGreeter { service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp); rpc greet (StreamReq) returns (StreamResp);
} }

@ -0,0 +1,27 @@
// test proto
syntax = "proto3";
package snake_package;
message StreamReq {
string name = 1;
}
message Stream_Resp {
string greet = 1;
}
message lowercase {
string in = 1;
string lower = 2;
}
message CamelCase {
string Camel = 1;
}
service Stream_Greeter {
rpc snake_service(StreamReq) returns (Stream_Resp);
rpc ServiceCamelCase(CamelCase) returns (CamelCase);
rpc servicelowercase(lowercase) returns (lowercase);
}

@ -12,5 +12,6 @@ message StreamResp {
} }
service StreamGreeter { service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp); // greet service
rpc greet (StreamReq) returns (StreamResp);
} }

@ -6,5 +6,5 @@ func GetComment(comment *proto.Comment) string {
if comment == nil { if comment == nil {
return "" return ""
} }
return comment.Message() return "// " + comment.Message()
} }

@ -8,6 +8,6 @@ import "base.proto";
message Reply{} message Reply{}
service TestService{ service TestService {
rpc TestRpcTwo (base.Req)returns(Reply); rpc TestRpcTwo (base.Req) returns (Reply);
} }

@ -8,6 +8,6 @@ import "base.proto";
message Req{} message Req{}
service TestService{ service TestService {
rpc TestRpcTwo (Req)returns(base.Reply); rpc TestRpcTwo (Req) returns (base.Reply);
} }

@ -2,9 +2,10 @@ syntax = "proto3";
package stream; package stream;
option go_package="github.com/tal-tech/go-zero"; option go_package = "github.com/tal-tech/go-zero";
message placeholder{} message placeholder {}
service greet{
rpc hello(placeholder)returns(placeholder); service greet {
rpc hello (placeholder) returns (placeholder);
} }

@ -3,7 +3,8 @@ syntax = "proto3";
package stream; package stream;
message placeholder{} message placeholder {}
service greet{
rpc hello(placeholder)returns(placeholder); service greet {
rpc hello (placeholder) returns (placeholder);
} }
Loading…
Cancel
Save