feat: refactor gateway code (#3160)

master
MarkJoyMa 2 years ago committed by GitHub
parent d10740f871
commit 9970ff55cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,25 +21,20 @@ import (
type ( type (
// Server is a gateway server. // Server is a gateway server.
Server struct { Server struct {
c GatewayConf
*rest.Server *rest.Server
upstreams []*upstream upstreams []Upstream
processHeader func(http.Header) []string processHeader func(http.Header) []string
dialer func(conf zrpc.RpcClientConf) zrpc.Client
} }
// Option defines the method to customize Server. // Option defines the method to customize Server.
Option func(svr *Server) Option func(svr *Server)
upstream struct {
Upstream
client zrpc.Client
}
) )
// MustNewServer creates a new gateway server. // MustNewServer creates a new gateway server.
func MustNewServer(c GatewayConf, opts ...Option) *Server { func MustNewServer(c GatewayConf, opts ...Option) *Server {
svr := &Server{ svr := &Server{
c: c, upstreams: c.Upstreams,
Server: rest.MustNewServer(c.RestConf), Server: rest.MustNewServer(c.RestConf),
} }
for _, opt := range opts { for _, opt := range opts {
@ -61,23 +56,15 @@ func (s *Server) Stop() {
} }
func (s *Server) build() error { func (s *Server) build() error {
if err := s.buildClient(); err != nil {
return err
}
return s.buildUpstream()
}
func (s *Server) buildClient() error {
if err := s.ensureUpstreamNames(); err != nil { if err := s.ensureUpstreamNames(); err != nil {
return err return err
} }
return mr.MapReduceVoid(func(source chan<- Upstream) { return mr.MapReduceVoid(func(source chan<- Upstream) {
for _, up := range s.c.Upstreams { for _, up := range s.upstreams {
source <- up source <- up
} }
}, func(up Upstream, writer mr.Writer[*upstream], cancel func(error)) { }, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
target, err := up.Grpc.BuildTarget() target, err := up.Grpc.BuildTarget()
if err != nil { if err != nil {
cancel(err) cancel(err)
@ -85,26 +72,14 @@ func (s *Server) buildClient() error {
} }
up.Name = target up.Name = target
cli := zrpc.MustNewClient(up.Grpc) var cli zrpc.Client
writer.Write(&upstream{ if s.dialer != nil {
Upstream: up, cli = s.dialer(up.Grpc)
client: cli, } else {
}) cli = zrpc.MustNewClient(up.Grpc)
}, func(pipe <-chan *upstream, cancel func(error)) {
for up := range pipe {
s.upstreams = append(s.upstreams, up)
} }
})
}
func (s *Server) buildUpstream() error { source, err := s.createDescriptorSource(cli, up)
return mr.MapReduceVoid(func(source chan<- *upstream) {
for _, up := range s.upstreams {
source <- up
}
}, func(up *upstream, writer mr.Writer[rest.Route], cancel func(error)) {
cli := up.client
source, err := s.createDescriptorSource(cli, up.Upstream)
if err != nil { if err != nil {
cancel(fmt.Errorf("%s: %w", up.Name, err)) cancel(fmt.Errorf("%s: %w", up.Name, err))
return return
@ -191,13 +166,13 @@ func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.D
} }
func (s *Server) ensureUpstreamNames() error { func (s *Server) ensureUpstreamNames() error {
for _, up := range s.c.Upstreams { for i := 0; i < len(s.upstreams); i++ {
target, err := up.Grpc.BuildTarget() target, err := s.upstreams[i].Grpc.BuildTarget()
if err != nil { if err != nil {
return err return err
} }
up.Name = target s.upstreams[i].Name = target
} }
return nil return nil
@ -219,3 +194,10 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
s.processHeader = processHeader s.processHeader = processHeader
} }
} }
// withDialer sets a dialer to create a gRPC client.
func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) {
return func(s *Server) {
s.dialer = dialer
}
}

@ -49,10 +49,11 @@ func TestMustNewServer(t *testing.T) {
c.Host = "localhost" c.Host = "localhost"
c.Port = 18881 c.Port = 18881
s := MustNewServer(c) s := MustNewServer(c, withDialer(func(conf zrpc.RpcClientConf) zrpc.Client {
s.upstreams = []*upstream{ return zrpc.MustNewClient(conf, zrpc.WithDialOption(grpc.WithContextDialer(dialer())))
}))
s.upstreams = []Upstream{
{ {
Upstream: Upstream{
Mappings: []RouteMapping{ Mappings: []RouteMapping{
{ {
Method: "get", Method: "get",
@ -60,9 +61,7 @@ func TestMustNewServer(t *testing.T) {
RpcPath: "mock.DepositService/Deposit", RpcPath: "mock.DepositService/Deposit",
}, },
}, },
}, Grpc: zrpc.RpcClientConf{
client: zrpc.MustNewClient(
zrpc.RpcClientConf{
Endpoints: []string{"foo"}, Endpoints: []string{"foo"},
Timeout: 1000, Timeout: 1000,
Middlewares: zrpc.ClientMiddlewaresConf{ Middlewares: zrpc.ClientMiddlewaresConf{
@ -73,15 +72,13 @@ func TestMustNewServer(t *testing.T) {
Timeout: true, Timeout: true,
}, },
}, },
zrpc.WithDialOption(grpc.WithContextDialer(dialer())),
),
}, },
} }
assert.NoError(t, s.buildUpstream()) assert.NoError(t, s.build())
go s.Server.Start() go s.Server.Start()
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 200)
resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil) resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -91,3 +88,18 @@ func TestMustNewServer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.Equal(t, http.StatusNotFound, resp.StatusCode)
} }
func TestServer_ensureUpstreamNames(t *testing.T) {
var s = Server{
upstreams: []Upstream{
{
Grpc: zrpc.RpcClientConf{
Target: "target",
},
},
},
}
assert.NoError(t, s.ensureUpstreamNames())
assert.Equal(t, "target", s.upstreams[0].Name)
}

Loading…
Cancel
Save