diff --git a/rest/engine.go b/rest/engine.go new file mode 100644 index 00000000..69496868 --- /dev/null +++ b/rest/engine.go @@ -0,0 +1,214 @@ +package rest + +import ( + "errors" + "fmt" + "net/http" + "time" + + "github.com/justinas/alice" + "github.com/tal-tech/go-zero/core/codec" + "github.com/tal-tech/go-zero/core/load" + "github.com/tal-tech/go-zero/core/stat" + "github.com/tal-tech/go-zero/rest/handler" + "github.com/tal-tech/go-zero/rest/httpx" + "github.com/tal-tech/go-zero/rest/internal" + "github.com/tal-tech/go-zero/rest/router" +) + +// use 1000m to represent 100% +const topCpuUsage = 1000 + +var ErrSignatureConfig = errors.New("bad config for Signature") + +type engine struct { + conf RestConf + routes []featuredRoutes + unauthorizedCallback handler.UnauthorizedCallback + unsignedCallback handler.UnsignedCallback + middlewares []Middleware + shedder load.Shedder + priorityShedder load.Shedder +} + +func newEngine(c RestConf) *engine { + srv := &engine{ + conf: c, + } + if c.CpuThreshold > 0 { + srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) + srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold( + (c.CpuThreshold + topCpuUsage) >> 1)) + } + + return srv +} + +func (s *engine) AddRoutes(r featuredRoutes) { + s.routes = append(s.routes, r) +} + +func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) { + s.unauthorizedCallback = callback +} + +func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) { + s.unsignedCallback = callback +} + +func (s *engine) Start() error { + return s.StartWithRouter(router.NewPatRouter()) +} + +func (s *engine) StartWithRouter(router httpx.Router) error { + if err := s.bindRoutes(router); err != nil { + return err + } + + return internal.StartHttp(s.conf.Host, s.conf.Port, router) +} + +func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, + verifier func(alice.Chain) alice.Chain) alice.Chain { + if fr.jwt.enabled { + if len(fr.jwt.prevSecret) == 0 { + chain = chain.Append(handler.Authorize(fr.jwt.secret, + handler.WithUnauthorizedCallback(s.unauthorizedCallback))) + } else { + chain = chain.Append(handler.Authorize(fr.jwt.secret, + handler.WithPrevSecret(fr.jwt.prevSecret), + handler.WithUnauthorizedCallback(s.unauthorizedCallback))) + } + } + + return verifier(chain) +} + +func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { + verifier, err := s.signatureVerifier(fr.signature) + if err != nil { + return err + } + + for _, route := range fr.routes { + if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil { + return err + } + } + + return nil +} + +func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, + route Route, verifier func(chain alice.Chain) alice.Chain) error { + chain := alice.New( + handler.TracingHandler, + s.getLogHandler(), + handler.MaxConns(s.conf.MaxConns), + handler.BreakerHandler(route.Method, route.Path, metrics), + handler.SheddingHandler(s.getShedder(fr.priority), metrics), + handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), + handler.RecoverHandler, + handler.MetricHandler(metrics), + handler.PromMetricHandler(route.Path), + handler.MaxBytesHandler(s.conf.MaxBytes), + handler.GunzipHandler, + ) + chain = s.appendAuthHandler(fr, chain, verifier) + + for _, middleware := range s.middlewares { + chain = chain.Append(convertMiddleware(middleware)) + } + handle := chain.ThenFunc(route.Handler) + + return router.Handle(route.Method, route.Path, handle) +} + +func (s *engine) bindRoutes(router httpx.Router) error { + metrics := s.createMetrics() + + for _, fr := range s.routes { + if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil { + return err + } + } + + return nil +} + +func (s *engine) createMetrics() *stat.Metrics { + var metrics *stat.Metrics + + if len(s.conf.Name) > 0 { + metrics = stat.NewMetrics(s.conf.Name) + } else { + metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port)) + } + + return metrics +} + +func (s *engine) getLogHandler() func(http.Handler) http.Handler { + if s.conf.Verbose { + return handler.DetailedLogHandler + } else { + return handler.LogHandler + } +} + +func (s *engine) getShedder(priority bool) load.Shedder { + if priority && s.priorityShedder != nil { + return s.priorityShedder + } + return s.shedder +} + +func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { + if !signature.enabled { + return func(chain alice.Chain) alice.Chain { + return chain + }, nil + } + + if len(signature.PrivateKeys) == 0 { + if signature.Strict { + return nil, ErrSignatureConfig + } else { + return func(chain alice.Chain) alice.Chain { + return chain + }, nil + } + } + + decrypters := make(map[string]codec.RsaDecrypter) + for _, key := range signature.PrivateKeys { + fingerprint := key.Fingerprint + file := key.KeyFile + decrypter, err := codec.NewRsaDecrypter(file) + if err != nil { + return nil, err + } + + decrypters[fingerprint] = decrypter + } + + return func(chain alice.Chain) alice.Chain { + if s.unsignedCallback != nil { + return chain.Append(handler.ContentSecurityHandler( + decrypters, signature.Expiry, signature.Strict, s.unsignedCallback)) + } else { + return chain.Append(handler.ContentSecurityHandler( + decrypters, signature.Expiry, signature.Strict)) + } + }, nil +} + +func (s *engine) use(middleware Middleware) { + s.middlewares = append(s.middlewares, middleware) +} + +func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(ware(next.ServeHTTP)) + } +} diff --git a/rest/ngin.go b/rest/ngin.go deleted file mode 100644 index 64d6f6e4..00000000 --- a/rest/ngin.go +++ /dev/null @@ -1,170 +0,0 @@ -package rest - -import ( - "log" - "net/http" - - "github.com/tal-tech/go-zero/core/logx" - "github.com/tal-tech/go-zero/rest/handler" - "github.com/tal-tech/go-zero/rest/httpx" -) - -type ( - runOptions struct { - start func(*engine) error - } - - RunOption func(*Server) - - Server struct { - ngin *engine - opts runOptions - } -) - -func MustNewServer(c RestConf, opts ...RunOption) *Server { - engine, err := NewServer(c, opts...) - if err != nil { - log.Fatal(err) - } - - return engine -} - -func NewServer(c RestConf, opts ...RunOption) (*Server, error) { - if err := c.SetUp(); err != nil { - return nil, err - } - - server := &Server{ - ngin: newEngine(c), - opts: runOptions{ - start: func(srv *engine) error { - return srv.Start() - }, - }, - } - - for _, opt := range opts { - opt(server) - } - - return server, nil -} - -func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) { - r := featuredRoutes{ - routes: rs, - } - for _, opt := range opts { - opt(&r) - } - e.ngin.AddRoutes(r) -} - -func (e *Server) AddRoute(r Route, opts ...RouteOption) { - e.AddRoutes([]Route{r}, opts...) -} - -func (e *Server) Start() { - handleError(e.opts.start(e.ngin)) -} - -func (e *Server) Stop() { - logx.Close() -} - -func (e *Server) Use(middleware Middleware) { - e.ngin.use(middleware) -} - -func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { - return func(handle http.HandlerFunc) http.HandlerFunc { - return handler(handle).ServeHTTP - } -} - -func WithJwt(secret string) RouteOption { - return func(r *featuredRoutes) { - validateSecret(secret) - r.jwt.enabled = true - r.jwt.secret = secret - } -} - -func WithJwtTransition(secret, prevSecret string) RouteOption { - return func(r *featuredRoutes) { - // why not validate prevSecret, because prevSecret is an already used one, - // even it not meet our requirement, we still need to allow the transition. - validateSecret(secret) - r.jwt.enabled = true - r.jwt.secret = secret - r.jwt.prevSecret = prevSecret - } -} - -func WithMiddleware(middleware Middleware, rs ...Route) []Route { - routes := make([]Route, len(rs)) - - for i := range rs { - route := rs[i] - routes[i] = Route{ - Method: route.Method, - Path: route.Path, - Handler: middleware(route.Handler), - } - } - - return routes -} - -func WithPriority() RouteOption { - return func(r *featuredRoutes) { - r.priority = true - } -} - -func WithRouter(router httpx.Router) RunOption { - return func(server *Server) { - server.opts.start = func(srv *engine) error { - return srv.StartWithRouter(router) - } - } -} - -func WithSignature(signature SignatureConf) RouteOption { - return func(r *featuredRoutes) { - r.signature.enabled = true - r.signature.Strict = signature.Strict - r.signature.Expiry = signature.Expiry - r.signature.PrivateKeys = signature.PrivateKeys - } -} - -func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { - return func(engine *Server) { - engine.ngin.SetUnauthorizedCallback(callback) - } -} - -func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { - return func(engine *Server) { - engine.ngin.SetUnsignedCallback(callback) - } -} - -func handleError(err error) { - // ErrServerClosed means the server is closed manually - if err == nil || err == http.ErrServerClosed { - return - } - - logx.Error(err) - panic(err) -} - -func validateSecret(secret string) { - if len(secret) < 8 { - panic("secret's length can't be less than 8") - } -} diff --git a/rest/server.go b/rest/server.go index 69496868..64d6f6e4 100644 --- a/rest/server.go +++ b/rest/server.go @@ -1,214 +1,170 @@ package rest import ( - "errors" - "fmt" + "log" "net/http" - "time" - "github.com/justinas/alice" - "github.com/tal-tech/go-zero/core/codec" - "github.com/tal-tech/go-zero/core/load" - "github.com/tal-tech/go-zero/core/stat" + "github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/httpx" - "github.com/tal-tech/go-zero/rest/internal" - "github.com/tal-tech/go-zero/rest/router" ) -// use 1000m to represent 100% -const topCpuUsage = 1000 +type ( + runOptions struct { + start func(*engine) error + } -var ErrSignatureConfig = errors.New("bad config for Signature") + RunOption func(*Server) -type engine struct { - conf RestConf - routes []featuredRoutes - unauthorizedCallback handler.UnauthorizedCallback - unsignedCallback handler.UnsignedCallback - middlewares []Middleware - shedder load.Shedder - priorityShedder load.Shedder + Server struct { + ngin *engine + opts runOptions + } +) + +func MustNewServer(c RestConf, opts ...RunOption) *Server { + engine, err := NewServer(c, opts...) + if err != nil { + log.Fatal(err) + } + + return engine } -func newEngine(c RestConf) *engine { - srv := &engine{ - conf: c, +func NewServer(c RestConf, opts ...RunOption) (*Server, error) { + if err := c.SetUp(); err != nil { + return nil, err } - if c.CpuThreshold > 0 { - srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) - srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold( - (c.CpuThreshold + topCpuUsage) >> 1)) + + server := &Server{ + ngin: newEngine(c), + opts: runOptions{ + start: func(srv *engine) error { + return srv.Start() + }, + }, } - return srv -} + for _, opt := range opts { + opt(server) + } -func (s *engine) AddRoutes(r featuredRoutes) { - s.routes = append(s.routes, r) + return server, nil } -func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) { - s.unauthorizedCallback = callback +func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) { + r := featuredRoutes{ + routes: rs, + } + for _, opt := range opts { + opt(&r) + } + e.ngin.AddRoutes(r) } -func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) { - s.unsignedCallback = callback +func (e *Server) AddRoute(r Route, opts ...RouteOption) { + e.AddRoutes([]Route{r}, opts...) } -func (s *engine) Start() error { - return s.StartWithRouter(router.NewPatRouter()) +func (e *Server) Start() { + handleError(e.opts.start(e.ngin)) } -func (s *engine) StartWithRouter(router httpx.Router) error { - if err := s.bindRoutes(router); err != nil { - return err - } - - return internal.StartHttp(s.conf.Host, s.conf.Port, router) +func (e *Server) Stop() { + logx.Close() } -func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, - verifier func(alice.Chain) alice.Chain) alice.Chain { - if fr.jwt.enabled { - if len(fr.jwt.prevSecret) == 0 { - chain = chain.Append(handler.Authorize(fr.jwt.secret, - handler.WithUnauthorizedCallback(s.unauthorizedCallback))) - } else { - chain = chain.Append(handler.Authorize(fr.jwt.secret, - handler.WithPrevSecret(fr.jwt.prevSecret), - handler.WithUnauthorizedCallback(s.unauthorizedCallback))) - } - } - - return verifier(chain) +func (e *Server) Use(middleware Middleware) { + e.ngin.use(middleware) } -func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { - verifier, err := s.signatureVerifier(fr.signature) - if err != nil { - return err +func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { + return func(handle http.HandlerFunc) http.HandlerFunc { + return handler(handle).ServeHTTP } +} - for _, route := range fr.routes { - if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil { - return err - } +func WithJwt(secret string) RouteOption { + return func(r *featuredRoutes) { + validateSecret(secret) + r.jwt.enabled = true + r.jwt.secret = secret } - - return nil } -func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, - route Route, verifier func(chain alice.Chain) alice.Chain) error { - chain := alice.New( - handler.TracingHandler, - s.getLogHandler(), - handler.MaxConns(s.conf.MaxConns), - handler.BreakerHandler(route.Method, route.Path, metrics), - handler.SheddingHandler(s.getShedder(fr.priority), metrics), - handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), - handler.RecoverHandler, - handler.MetricHandler(metrics), - handler.PromMetricHandler(route.Path), - handler.MaxBytesHandler(s.conf.MaxBytes), - handler.GunzipHandler, - ) - chain = s.appendAuthHandler(fr, chain, verifier) - - for _, middleware := range s.middlewares { - chain = chain.Append(convertMiddleware(middleware)) +func WithJwtTransition(secret, prevSecret string) RouteOption { + return func(r *featuredRoutes) { + // why not validate prevSecret, because prevSecret is an already used one, + // even it not meet our requirement, we still need to allow the transition. + validateSecret(secret) + r.jwt.enabled = true + r.jwt.secret = secret + r.jwt.prevSecret = prevSecret } - handle := chain.ThenFunc(route.Handler) - - return router.Handle(route.Method, route.Path, handle) } -func (s *engine) bindRoutes(router httpx.Router) error { - metrics := s.createMetrics() +func WithMiddleware(middleware Middleware, rs ...Route) []Route { + routes := make([]Route, len(rs)) - for _, fr := range s.routes { - if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil { - return err + for i := range rs { + route := rs[i] + routes[i] = Route{ + Method: route.Method, + Path: route.Path, + Handler: middleware(route.Handler), } } - return nil + return routes } -func (s *engine) createMetrics() *stat.Metrics { - var metrics *stat.Metrics - - if len(s.conf.Name) > 0 { - metrics = stat.NewMetrics(s.conf.Name) - } else { - metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port)) +func WithPriority() RouteOption { + return func(r *featuredRoutes) { + r.priority = true } - - return metrics } -func (s *engine) getLogHandler() func(http.Handler) http.Handler { - if s.conf.Verbose { - return handler.DetailedLogHandler - } else { - return handler.LogHandler +func WithRouter(router httpx.Router) RunOption { + return func(server *Server) { + server.opts.start = func(srv *engine) error { + return srv.StartWithRouter(router) + } } } -func (s *engine) getShedder(priority bool) load.Shedder { - if priority && s.priorityShedder != nil { - return s.priorityShedder +func WithSignature(signature SignatureConf) RouteOption { + return func(r *featuredRoutes) { + r.signature.enabled = true + r.signature.Strict = signature.Strict + r.signature.Expiry = signature.Expiry + r.signature.PrivateKeys = signature.PrivateKeys } - return s.shedder } -func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { - if !signature.enabled { - return func(chain alice.Chain) alice.Chain { - return chain - }, nil +func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { + return func(engine *Server) { + engine.ngin.SetUnauthorizedCallback(callback) } +} - if len(signature.PrivateKeys) == 0 { - if signature.Strict { - return nil, ErrSignatureConfig - } else { - return func(chain alice.Chain) alice.Chain { - return chain - }, nil - } +func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { + return func(engine *Server) { + engine.ngin.SetUnsignedCallback(callback) } +} - decrypters := make(map[string]codec.RsaDecrypter) - for _, key := range signature.PrivateKeys { - fingerprint := key.Fingerprint - file := key.KeyFile - decrypter, err := codec.NewRsaDecrypter(file) - if err != nil { - return nil, err - } - - decrypters[fingerprint] = decrypter +func handleError(err error) { + // ErrServerClosed means the server is closed manually + if err == nil || err == http.ErrServerClosed { + return } - return func(chain alice.Chain) alice.Chain { - if s.unsignedCallback != nil { - return chain.Append(handler.ContentSecurityHandler( - decrypters, signature.Expiry, signature.Strict, s.unsignedCallback)) - } else { - return chain.Append(handler.ContentSecurityHandler( - decrypters, signature.Expiry, signature.Strict)) - } - }, nil -} - -func (s *engine) use(middleware Middleware) { - s.middlewares = append(s.middlewares, middleware) + logx.Error(err) + panic(err) } -func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(ware(next.ServeHTTP)) +func validateSecret(secret string) { + if len(secret) < 8 { + panic("secret's length can't be less than 8") } } diff --git a/rest/ngin_test.go b/rest/server_test.go similarity index 100% rename from rest/ngin_test.go rename to rest/server_test.go