diff --git a/go.mod b/go.mod index 369674db..125ef6e1 100644 --- a/go.mod +++ b/go.mod @@ -54,4 +54,5 @@ require ( k8s.io/api v0.20.10 k8s.io/apimachinery v0.20.10 k8s.io/client-go v0.20.10 + k8s.io/utils v0.0.0-20201110183641-67b214c5f920 ) diff --git a/rest/config.go b/rest/config.go index a806b723..972c87c1 100644 --- a/rest/config.go +++ b/rest/config.go @@ -35,7 +35,7 @@ type ( KeyFile string `json:",optional"` Verbose bool `json:",optional"` MaxConns int `json:",default=10000"` - MaxBytes int64 `json:",default=1048576,range=[0:33554432]"` + MaxBytes int64 `json:",default=1048576"` // milliseconds Timeout int64 `json:",default=3000"` CpuThreshold int64 `json:",default=900,range=[0:1000]"` diff --git a/rest/engine.go b/rest/engine.go index 75fcbc9b..410f5acf 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -47,58 +47,63 @@ func newEngine(c RestConf) *engine { return srv } -func (s *engine) AddRoutes(r featuredRoutes) { - s.routes = append(s.routes, r) +func (ng *engine) AddRoutes(r featuredRoutes) { + ng.routes = append(ng.routes, r) } -func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) { - s.unauthorizedCallback = callback +func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) { + ng.unauthorizedCallback = callback } -func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) { - s.unsignedCallback = callback +func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) { + ng.unsignedCallback = callback } -func (s *engine) Start() error { - return s.StartWithRouter(router.NewRouter()) +func (ng *engine) Start() error { + return ng.StartWithRouter(router.NewRouter()) } -func (s *engine) StartWithRouter(router httpx.Router) error { - if err := s.bindRoutes(router); err != nil { +func (ng *engine) StartWithRouter(router httpx.Router) error { + if err := ng.bindRoutes(router); err != nil { return err } - if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 { - return internal.StartHttp(s.conf.Host, s.conf.Port, router) + if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { + return internal.StartHttp(ng.conf.Host, ng.conf.Port, router) } - return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router) + return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, + ng.conf.KeyFile, router, func(srv *http.Server) { + if ng.tlsConfig != nil { + srv.TLSConfig = ng.tlsConfig + } + }) } -func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, +func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, verifier func(alice.Chain) alice.Chain) alice.Chain { if fr.jwt.enabled { if len(fr.jwt.prevSecret) == 0 { chain = chain.Append(handler.Authorize(fr.jwt.secret, - handler.WithUnauthorizedCallback(s.unauthorizedCallback))) + handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) } else { chain = chain.Append(handler.Authorize(fr.jwt.secret, handler.WithPrevSecret(fr.jwt.prevSecret), - handler.WithUnauthorizedCallback(s.unauthorizedCallback))) + handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) } } return verifier(chain) } -func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { - verifier, err := s.signatureVerifier(fr.signature) +func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { + verifier, err := ng.signatureVerifier(fr.signature) if err != nil { return err } for _, route := range fr.routes { - if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil { + if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil { return err } } @@ -106,24 +111,24 @@ func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metr return nil } -func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, +func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, route Route, verifier func(chain alice.Chain) alice.Chain) error { chain := alice.New( - handler.TracingHandler(s.conf.Name, route.Path), - s.getLogHandler(), + handler.TracingHandler(ng.conf.Name, route.Path), + ng.getLogHandler(), handler.PrometheusHandler(route.Path), - handler.MaxConns(s.conf.MaxConns), + handler.MaxConns(ng.conf.MaxConns), handler.BreakerHandler(route.Method, route.Path, metrics), - handler.SheddingHandler(s.getShedder(fr.priority), metrics), - handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), + handler.SheddingHandler(ng.getShedder(fr.priority), metrics), + handler.TimeoutHandler(time.Duration(ng.conf.Timeout)*time.Millisecond), handler.RecoverHandler, handler.MetricHandler(metrics), - handler.MaxBytesHandler(s.conf.MaxBytes), + handler.MaxBytesHandler(ng.conf.MaxBytes), handler.GunzipHandler, ) - chain = s.appendAuthHandler(fr, chain, verifier) + chain = ng.appendAuthHandler(fr, chain, verifier) - for _, middleware := range s.middlewares { + for _, middleware := range ng.middlewares { chain = chain.Append(convertMiddleware(middleware)) } handle := chain.ThenFunc(route.Handler) @@ -131,11 +136,11 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat return router.Handle(route.Method, route.Path, handle) } -func (s *engine) bindRoutes(router httpx.Router) error { - metrics := s.createMetrics() +func (ng *engine) bindRoutes(router httpx.Router) error { + metrics := ng.createMetrics() - for _, fr := range s.routes { - if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil { + for _, fr := range ng.routes { + if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil { return err } } @@ -143,35 +148,39 @@ func (s *engine) bindRoutes(router httpx.Router) error { return nil } -func (s *engine) createMetrics() *stat.Metrics { +func (ng *engine) createMetrics() *stat.Metrics { var metrics *stat.Metrics - if len(s.conf.Name) > 0 { - metrics = stat.NewMetrics(s.conf.Name) + if len(ng.conf.Name) > 0 { + metrics = stat.NewMetrics(ng.conf.Name) } else { - metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port)) + metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)) } return metrics } -func (s *engine) getLogHandler() func(http.Handler) http.Handler { - if s.conf.Verbose { +func (ng *engine) getLogHandler() func(http.Handler) http.Handler { + if ng.conf.Verbose { return handler.DetailedLogHandler } return handler.LogHandler } -func (s *engine) getShedder(priority bool) load.Shedder { - if priority && s.priorityShedder != nil { - return s.priorityShedder +func (ng *engine) getShedder(priority bool) load.Shedder { + if priority && ng.priorityShedder != nil { + return ng.priorityShedder } - return s.shedder + return ng.shedder } -func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { +func (ng *engine) setTlsConfig(cfg *tls.Config) { + ng.tlsConfig = cfg +} + +func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { if !signature.enabled { return func(chain alice.Chain) alice.Chain { return chain @@ -201,9 +210,9 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice } return func(chain alice.Chain) alice.Chain { - if s.unsignedCallback != nil { + if ng.unsignedCallback != nil { return chain.Append(handler.ContentSecurityHandler( - decrypters, signature.Expiry, signature.Strict, s.unsignedCallback)) + decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback)) } return chain.Append(handler.ContentSecurityHandler( @@ -211,8 +220,8 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice }, nil } -func (s *engine) use(middleware Middleware) { - s.middlewares = append(s.middlewares, middleware) +func (ng *engine) use(middleware Middleware) { + ng.middlewares = append(ng.middlewares, middleware) } func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { diff --git a/rest/internal/starter.go b/rest/internal/starter.go index 5af8b374..0c549ae7 100644 --- a/rest/internal/starter.go +++ b/rest/internal/starter.go @@ -2,38 +2,46 @@ package internal import ( "context" - "crypto/tls" "fmt" "net/http" + "github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/proc" ) +// StartOption defines the method to customize http.Server. +type StartOption func(srv *http.Server) + // StartHttp starts a http server. -func StartHttp(host string, port int, handler http.Handler) error { - return start(host, port, handler, nil, func(srv *http.Server) error { +func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error { + return start(host, port, handler, func(srv *http.Server) error { return srv.ListenAndServe() - }) + }, opts...) } // StartHttps starts a https server. -func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error { - return start(host, port, handler, tlsConfig, func(srv *http.Server) error { +func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler, + opts ...StartOption) error { + return start(host, port, handler, func(srv *http.Server) error { // certFile and keyFile are set in buildHttpsServer return srv.ListenAndServeTLS(certFile, keyFile) - }) + }, opts...) } -func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) { +func start(host string, port int, handler http.Handler, run func(srv *http.Server) error, + opts ...StartOption) (err error) { server := &http.Server{ Addr: fmt.Sprintf("%s:%d", host, port), Handler: handler, } - if tlsConfig != nil { - server.TLSConfig = tlsConfig + for _, opt := range opts { + opt(server) } + waitForCalled := proc.AddWrapUpListener(func() { - server.Shutdown(context.Background()) + if e := server.Shutdown(context.Background()); err != nil { + logx.Error(e) + } }) defer func() { if err == http.ErrServerClosed { diff --git a/rest/server.go b/rest/server.go index 4a7b3605..0046712c 100644 --- a/rest/server.go +++ b/rest/server.go @@ -48,8 +48,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) { server := &Server{ ngin: newEngine(c), opts: runOptions{ - start: func(srv *engine) error { - return srv.Start() + start: func(ng *engine) error { + return ng.Start() }, }, } @@ -171,8 +171,8 @@ func WithPriority() RouteOption { // WithRouter returns a RunOption that make server run with given router. func WithRouter(router httpx.Router) RunOption { return func(server *Server) { - server.opts.start = func(srv *engine) error { - return srv.StartWithRouter(router) + server.opts.start = func(ng *engine) error { + return ng.StartWithRouter(router) } } } @@ -187,26 +187,24 @@ func WithSignature(signature SignatureConf) RouteOption { } } -// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. -func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { - return func(engine *Server) { - engine.ngin.SetUnauthorizedCallback(callback) +// WithTLSConfig returns a RunOption that with given tls config. +func WithTLSConfig(cfg *tls.Config) RunOption { + return func(srv *Server) { + srv.ngin.setTlsConfig(cfg) } } -// WithTLSConfig returns a RunOption that with given tls config. -func WithTLSConfig(cipherSuites []uint16) RunOption { - return func(engine *Server) { - engine.ngin.tlsConfig = &tls.Config{ - CipherSuites: cipherSuites, - } +// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. +func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { + return func(srv *Server) { + srv.ngin.SetUnauthorizedCallback(callback) } } // WithUnsignedCallback returns a RunOption that with given unsigned callback set. func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { - return func(engine *Server) { - engine.ngin.SetUnsignedCallback(callback) + return func(srv *Server) { + srv.ngin.SetUnsignedCallback(callback) } } diff --git a/rest/server_test.go b/rest/server_test.go index 3958e796..811e9ee4 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -227,8 +227,10 @@ Port: 54321 var cnf RestConf assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) - testConfig := []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + testConfig := &tls.Config{ + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, } testCases := []struct { @@ -239,7 +241,7 @@ Port: 54321 { c: cnf, opts: []RunOption{WithTLSConfig(testConfig)}, - res: &tls.Config{CipherSuites: testConfig}, + res: testConfig, }, { c: cnf,