diff --git a/gateway/config.go b/gateway/config.go index 9c7b1a78..a26d1447 100644 --- a/gateway/config.go +++ b/gateway/config.go @@ -31,6 +31,8 @@ type ( Grpc zrpc.RpcClientConf // ProtoSet is the file of proto set, like hello.pb ProtoSet string `json:",optional"` - Mapping []mapping + // Mapping is the mapping between gateway routes and upstream rpc methods. + // Keep it blank if annotations are added in rpc methods. + Mapping []mapping `json:",optional"` } ) diff --git a/gateway/internal/descriptorsource.go b/gateway/internal/descriptorsource.go index 7925ece4..8e1895bc 100644 --- a/gateway/internal/descriptorsource.go +++ b/gateway/internal/descriptorsource.go @@ -2,19 +2,29 @@ package internal import ( "fmt" + "net/http" + "strings" "github.com/fullstorydev/grpcurl" "github.com/jhump/protoreflect/desc" + "google.golang.org/genproto/googleapis/api/annotations" + "google.golang.org/protobuf/proto" ) +type Method struct { + HttpMethod string + HttpPath string + RpcPath string +} + // GetMethods returns all methods of the given grpcurl.DescriptorSource. -func GetMethods(source grpcurl.DescriptorSource) ([]string, error) { +func GetMethods(source grpcurl.DescriptorSource) ([]Method, error) { svcs, err := source.ListServices() if err != nil { return nil, err } - var methods []string + var methods []Method for _, svc := range svcs { d, err := source.FindSymbol(svc) if err != nil { @@ -25,10 +35,68 @@ func GetMethods(source grpcurl.DescriptorSource) ([]string, error) { case *desc.ServiceDescriptor: svcMethods := val.GetMethods() for _, method := range svcMethods { - methods = append(methods, fmt.Sprintf("%s/%s", svc, method.GetName())) + rpcPath := fmt.Sprintf("%s/%s", svc, method.GetName()) + ext := proto.GetExtension(method.GetMethodOptions(), annotations.E_Http) + if ext == nil { + methods = append(methods, Method{ + RpcPath: rpcPath, + }) + continue + } + + httpExt, ok := ext.(*annotations.HttpRule) + if !ok { + methods = append(methods, Method{ + RpcPath: rpcPath, + }) + continue + } + + switch rule := httpExt.GetPattern().(type) { + case *annotations.HttpRule_Get: + methods = append(methods, Method{ + HttpMethod: http.MethodGet, + HttpPath: adjustHttpPath(rule.Get), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Post: + methods = append(methods, Method{ + HttpMethod: http.MethodPost, + HttpPath: adjustHttpPath(rule.Post), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Put: + methods = append(methods, Method{ + HttpMethod: http.MethodPut, + HttpPath: adjustHttpPath(rule.Put), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Delete: + methods = append(methods, Method{ + HttpMethod: http.MethodDelete, + HttpPath: adjustHttpPath(rule.Delete), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Patch: + methods = append(methods, Method{ + HttpMethod: http.MethodPatch, + HttpPath: adjustHttpPath(rule.Patch), + RpcPath: rpcPath, + }) + default: + methods = append(methods, Method{ + RpcPath: rpcPath, + }) + } } } } return methods, nil } + +func adjustHttpPath(path string) string { + path = strings.ReplaceAll(path, "{", ":") + path = strings.ReplaceAll(path, "}", "") + return path +} diff --git a/gateway/internal/descriptorsource_test.go b/gateway/internal/descriptorsource_test.go index 8c066077..b9fe893f 100644 --- a/gateway/internal/descriptorsource_test.go +++ b/gateway/internal/descriptorsource_test.go @@ -25,5 +25,9 @@ func TestGetMethods(t *testing.T) { assert.Nil(t, err) methods, err := GetMethods(source) assert.Nil(t, err) - assert.EqualValues(t, []string{"hello.Hello/Ping"}, methods) + assert.EqualValues(t, []Method{ + { + RpcPath: "hello.Hello/Ping", + }, + }, methods) } diff --git a/gateway/server.go b/gateway/server.go index 2ab65908..15d4c3a6 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -66,11 +66,21 @@ func (s *Server) build() error { return } + resolver := grpcurl.AnyResolverFromDescriptorSource(source) + for _, m := range methods { + if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 { + writer.Write(rest.Route{ + Method: m.HttpMethod, + Path: m.HttpPath, + Handler: s.buildHandler(source, resolver, cli, m.RpcPath), + }) + } + } + methodSet := make(map[string]struct{}) for _, m := range methods { - methodSet[m] = struct{}{} + methodSet[m.RpcPath] = struct{}{} } - resolver := grpcurl.AnyResolverFromDescriptorSource(source) for _, m := range up.Mapping { if _, ok := methodSet[m.RpcPath]; !ok { cancel(fmt.Errorf("rpc method %s not found", m.RpcPath)) @@ -80,7 +90,7 @@ func (s *Server) build() error { writer.Write(rest.Route{ Method: strings.ToUpper(m.Method), Path: m.Path, - Handler: s.buildHandler(source, resolver, cli, m), + Handler: s.buildHandler(source, resolver, cli, m.RpcPath), }) } }, func(pipe <-chan interface{}, cancel func(error)) { @@ -92,7 +102,7 @@ func (s *Server) build() error { } func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver, - cli zrpc.Client, m mapping) func(http.ResponseWriter, *http.Request) { + cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { handler := &grpcurl.DefaultEventHandler{ Out: w, @@ -110,7 +120,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A defer can() w.Header().Set(httpx.ContentType, httpx.JsonContentType) - if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header), + if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, internal.BuildHeaders(r.Header), handler, parser.Next); err != nil { httpx.Error(w, err) }