feat: refactor gateway code (#3160)

master
MarkJoyMa 2 years ago committed by GitHub
parent d10740f871
commit 9970ff55cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,26 +21,21 @@ import (
type ( type (
// Server is a gateway server. // Server is a gateway server.
Server struct { Server struct {
c GatewayConf
*rest.Server *rest.Server
upstreams []*upstream upstreams []Upstream
processHeader func(http.Header) []string processHeader func(http.Header) []string
dialer func(conf zrpc.RpcClientConf) zrpc.Client
} }
// Option defines the method to customize Server. // Option defines the method to customize Server.
Option func(svr *Server) Option func(svr *Server)
upstream struct {
Upstream
client zrpc.Client
}
) )
// MustNewServer creates a new gateway server. // MustNewServer creates a new gateway server.
func MustNewServer(c GatewayConf, opts ...Option) *Server { func MustNewServer(c GatewayConf, opts ...Option) *Server {
svr := &Server{ svr := &Server{
c: c, upstreams: c.Upstreams,
Server: rest.MustNewServer(c.RestConf), Server: rest.MustNewServer(c.RestConf),
} }
for _, opt := range opts { for _, opt := range opts {
opt(svr) opt(svr)
@ -61,23 +56,15 @@ func (s *Server) Stop() {
} }
func (s *Server) build() error { func (s *Server) build() error {
if err := s.buildClient(); err != nil {
return err
}
return s.buildUpstream()
}
func (s *Server) buildClient() error {
if err := s.ensureUpstreamNames(); err != nil { if err := s.ensureUpstreamNames(); err != nil {
return err return err
} }
return mr.MapReduceVoid(func(source chan<- Upstream) { return mr.MapReduceVoid(func(source chan<- Upstream) {
for _, up := range s.c.Upstreams { for _, up := range s.upstreams {
source <- up source <- up
} }
}, func(up Upstream, writer mr.Writer[*upstream], cancel func(error)) { }, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
target, err := up.Grpc.BuildTarget() target, err := up.Grpc.BuildTarget()
if err != nil { if err != nil {
cancel(err) cancel(err)
@ -85,26 +72,14 @@ func (s *Server) buildClient() error {
} }
up.Name = target up.Name = target
cli := zrpc.MustNewClient(up.Grpc) var cli zrpc.Client
writer.Write(&upstream{ if s.dialer != nil {
Upstream: up, cli = s.dialer(up.Grpc)
client: cli, } else {
}) cli = zrpc.MustNewClient(up.Grpc)
}, func(pipe <-chan *upstream, cancel func(error)) {
for up := range pipe {
s.upstreams = append(s.upstreams, up)
} }
})
}
func (s *Server) buildUpstream() error { source, err := s.createDescriptorSource(cli, up)
return mr.MapReduceVoid(func(source chan<- *upstream) {
for _, up := range s.upstreams {
source <- up
}
}, func(up *upstream, writer mr.Writer[rest.Route], cancel func(error)) {
cli := up.client
source, err := s.createDescriptorSource(cli, up.Upstream)
if err != nil { if err != nil {
cancel(fmt.Errorf("%s: %w", up.Name, err)) cancel(fmt.Errorf("%s: %w", up.Name, err))
return return
@ -191,13 +166,13 @@ func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.D
} }
func (s *Server) ensureUpstreamNames() error { func (s *Server) ensureUpstreamNames() error {
for _, up := range s.c.Upstreams { for i := 0; i < len(s.upstreams); i++ {
target, err := up.Grpc.BuildTarget() target, err := s.upstreams[i].Grpc.BuildTarget()
if err != nil { if err != nil {
return err return err
} }
up.Name = target s.upstreams[i].Name = target
} }
return nil return nil
@ -219,3 +194,10 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
s.processHeader = processHeader s.processHeader = processHeader
} }
} }
// withDialer sets a dialer to create a gRPC client.
func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) {
return func(s *Server) {
s.dialer = dialer
}
}

@ -49,39 +49,36 @@ func TestMustNewServer(t *testing.T) {
c.Host = "localhost" c.Host = "localhost"
c.Port = 18881 c.Port = 18881
s := MustNewServer(c) s := MustNewServer(c, withDialer(func(conf zrpc.RpcClientConf) zrpc.Client {
s.upstreams = []*upstream{ return zrpc.MustNewClient(conf, zrpc.WithDialOption(grpc.WithContextDialer(dialer())))
}))
s.upstreams = []Upstream{
{ {
Upstream: Upstream{ Mappings: []RouteMapping{
Mappings: []RouteMapping{ {
{ Method: "get",
Method: "get", Path: "/deposit/:amount",
Path: "/deposit/:amount", RpcPath: "mock.DepositService/Deposit",
RpcPath: "mock.DepositService/Deposit",
},
}, },
}, },
client: zrpc.MustNewClient( Grpc: zrpc.RpcClientConf{
zrpc.RpcClientConf{ Endpoints: []string{"foo"},
Endpoints: []string{"foo"}, Timeout: 1000,
Timeout: 1000, Middlewares: zrpc.ClientMiddlewaresConf{
Middlewares: zrpc.ClientMiddlewaresConf{ Trace: true,
Trace: true, Duration: true,
Duration: true, Prometheus: true,
Prometheus: true, Breaker: true,
Breaker: true, Timeout: true,
Timeout: true,
},
}, },
zrpc.WithDialOption(grpc.WithContextDialer(dialer())), },
),
}, },
} }
assert.NoError(t, s.buildUpstream()) assert.NoError(t, s.build())
go s.Server.Start() go s.Server.Start()
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 200)
resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil) resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -91,3 +88,18 @@ func TestMustNewServer(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.Equal(t, http.StatusNotFound, resp.StatusCode)
} }
func TestServer_ensureUpstreamNames(t *testing.T) {
var s = Server{
upstreams: []Upstream{
{
Grpc: zrpc.RpcClientConf{
Target: "target",
},
},
},
}
assert.NoError(t, s.ensureUpstreamNames())
assert.Equal(t, "target", s.upstreams[0].Name)
}

Loading…
Cancel
Save