diff --git a/gateway/server.go b/gateway/server.go index e0b85e81..a525a228 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -21,26 +21,21 @@ import ( type ( // Server is a gateway server. Server struct { - c GatewayConf *rest.Server - upstreams []*upstream + upstreams []Upstream processHeader func(http.Header) []string + dialer func(conf zrpc.RpcClientConf) zrpc.Client } // Option defines the method to customize Server. Option func(svr *Server) - - upstream struct { - Upstream - client zrpc.Client - } ) // MustNewServer creates a new gateway server. func MustNewServer(c GatewayConf, opts ...Option) *Server { svr := &Server{ - c: c, - Server: rest.MustNewServer(c.RestConf), + upstreams: c.Upstreams, + Server: rest.MustNewServer(c.RestConf), } for _, opt := range opts { opt(svr) @@ -61,23 +56,15 @@ func (s *Server) Stop() { } func (s *Server) build() error { - if err := s.buildClient(); err != nil { - return err - } - - return s.buildUpstream() -} - -func (s *Server) buildClient() error { if err := s.ensureUpstreamNames(); err != nil { return err } return mr.MapReduceVoid(func(source chan<- Upstream) { - for _, up := range s.c.Upstreams { + for _, up := range s.upstreams { source <- up } - }, func(up Upstream, writer mr.Writer[*upstream], cancel func(error)) { + }, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) { target, err := up.Grpc.BuildTarget() if err != nil { cancel(err) @@ -85,26 +72,14 @@ func (s *Server) buildClient() error { } up.Name = target - cli := zrpc.MustNewClient(up.Grpc) - writer.Write(&upstream{ - Upstream: up, - client: cli, - }) - }, func(pipe <-chan *upstream, cancel func(error)) { - for up := range pipe { - s.upstreams = append(s.upstreams, up) + var cli zrpc.Client + if s.dialer != nil { + cli = s.dialer(up.Grpc) + } else { + cli = zrpc.MustNewClient(up.Grpc) } - }) -} -func (s *Server) buildUpstream() error { - return mr.MapReduceVoid(func(source chan<- *upstream) { - for _, up := range s.upstreams { - source <- up - } - }, func(up *upstream, writer mr.Writer[rest.Route], cancel func(error)) { - cli := up.client - source, err := s.createDescriptorSource(cli, up.Upstream) + source, err := s.createDescriptorSource(cli, up) if err != nil { cancel(fmt.Errorf("%s: %w", up.Name, err)) return @@ -191,13 +166,13 @@ func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.D } func (s *Server) ensureUpstreamNames() error { - for _, up := range s.c.Upstreams { - target, err := up.Grpc.BuildTarget() + for i := 0; i < len(s.upstreams); i++ { + target, err := s.upstreams[i].Grpc.BuildTarget() if err != nil { return err } - up.Name = target + s.upstreams[i].Name = target } return nil @@ -219,3 +194,10 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server) s.processHeader = processHeader } } + +// withDialer sets a dialer to create a gRPC client. +func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) { + return func(s *Server) { + s.dialer = dialer + } +} diff --git a/gateway/server_test.go b/gateway/server_test.go index c9d168d9..6d21c2f1 100644 --- a/gateway/server_test.go +++ b/gateway/server_test.go @@ -49,39 +49,36 @@ func TestMustNewServer(t *testing.T) { c.Host = "localhost" c.Port = 18881 - s := MustNewServer(c) - s.upstreams = []*upstream{ + s := MustNewServer(c, withDialer(func(conf zrpc.RpcClientConf) zrpc.Client { + return zrpc.MustNewClient(conf, zrpc.WithDialOption(grpc.WithContextDialer(dialer()))) + })) + s.upstreams = []Upstream{ { - Upstream: Upstream{ - Mappings: []RouteMapping{ - { - Method: "get", - Path: "/deposit/:amount", - RpcPath: "mock.DepositService/Deposit", - }, + Mappings: []RouteMapping{ + { + Method: "get", + Path: "/deposit/:amount", + RpcPath: "mock.DepositService/Deposit", }, }, - client: zrpc.MustNewClient( - zrpc.RpcClientConf{ - Endpoints: []string{"foo"}, - Timeout: 1000, - Middlewares: zrpc.ClientMiddlewaresConf{ - Trace: true, - Duration: true, - Prometheus: true, - Breaker: true, - Timeout: true, - }, + Grpc: zrpc.RpcClientConf{ + Endpoints: []string{"foo"}, + Timeout: 1000, + Middlewares: zrpc.ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, }, - zrpc.WithDialOption(grpc.WithContextDialer(dialer())), - ), + }, }, } - assert.NoError(t, s.buildUpstream()) + assert.NoError(t, s.build()) go s.Server.Start() - time.Sleep(time.Millisecond * 100) + time.Sleep(time.Millisecond * 200) resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil) assert.NoError(t, err) @@ -91,3 +88,18 @@ func TestMustNewServer(t *testing.T) { assert.NoError(t, err) assert.Equal(t, http.StatusNotFound, resp.StatusCode) } + +func TestServer_ensureUpstreamNames(t *testing.T) { + var s = Server{ + upstreams: []Upstream{ + { + Grpc: zrpc.RpcClientConf{ + Target: "target", + }, + }, + }, + } + + assert.NoError(t, s.ensureUpstreamNames()) + assert.Equal(t, "target", s.upstreams[0].Name) +}