refactor: simplify tls config in rest (#1181)

master
Kevin Wan 3 years ago committed by GitHub
parent cd1f8da13f
commit 769d06c8ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,4 +54,5 @@ require (
k8s.io/api v0.20.10
k8s.io/apimachinery v0.20.10
k8s.io/client-go v0.20.10
k8s.io/utils v0.0.0-20201110183641-67b214c5f920
)

@ -35,7 +35,7 @@ type (
KeyFile string `json:",optional"`
Verbose bool `json:",optional"`
MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576,range=[0:33554432]"`
MaxBytes int64 `json:",default=1048576"`
// milliseconds
Timeout int64 `json:",default=3000"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"`

@ -47,58 +47,63 @@ func newEngine(c RestConf) *engine {
return srv
}
func (s *engine) AddRoutes(r featuredRoutes) {
s.routes = append(s.routes, r)
func (ng *engine) AddRoutes(r featuredRoutes) {
ng.routes = append(ng.routes, r)
}
func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
s.unauthorizedCallback = callback
func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
ng.unauthorizedCallback = callback
}
func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
s.unsignedCallback = callback
func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
ng.unsignedCallback = callback
}
func (s *engine) Start() error {
return s.StartWithRouter(router.NewRouter())
func (ng *engine) Start() error {
return ng.StartWithRouter(router.NewRouter())
}
func (s *engine) StartWithRouter(router httpx.Router) error {
if err := s.bindRoutes(router); err != nil {
func (ng *engine) StartWithRouter(router httpx.Router) error {
if err := ng.bindRoutes(router); err != nil {
return err
}
if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 {
return internal.StartHttp(s.conf.Host, s.conf.Port, router)
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
}
return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router)
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
ng.conf.KeyFile, router, func(srv *http.Server) {
if ng.tlsConfig != nil {
srv.TLSConfig = ng.tlsConfig
}
})
}
func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
verifier func(alice.Chain) alice.Chain) alice.Chain {
if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 {
chain = chain.Append(handler.Authorize(fr.jwt.secret,
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} else {
chain = chain.Append(handler.Authorize(fr.jwt.secret,
handler.WithPrevSecret(fr.jwt.prevSecret),
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
}
}
return verifier(chain)
}
func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
verifier, err := s.signatureVerifier(fr.signature)
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
verifier, err := ng.signatureVerifier(fr.signature)
if err != nil {
return err
}
for _, route := range fr.routes {
if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
return err
}
}
@ -106,24 +111,24 @@ func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metr
return nil
}
func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error {
chain := alice.New(
handler.TracingHandler(s.conf.Name, route.Path),
s.getLogHandler(),
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(s.conf.MaxConns),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(time.Duration(ng.conf.Timeout)*time.Millisecond),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(s.conf.MaxBytes),
handler.MaxBytesHandler(ng.conf.MaxBytes),
handler.GunzipHandler,
)
chain = s.appendAuthHandler(fr, chain, verifier)
chain = ng.appendAuthHandler(fr, chain, verifier)
for _, middleware := range s.middlewares {
for _, middleware := range ng.middlewares {
chain = chain.Append(convertMiddleware(middleware))
}
handle := chain.ThenFunc(route.Handler)
@ -131,11 +136,11 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
return router.Handle(route.Method, route.Path, handle)
}
func (s *engine) bindRoutes(router httpx.Router) error {
metrics := s.createMetrics()
func (ng *engine) bindRoutes(router httpx.Router) error {
metrics := ng.createMetrics()
for _, fr := range s.routes {
if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
for _, fr := range ng.routes {
if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
return err
}
}
@ -143,35 +148,39 @@ func (s *engine) bindRoutes(router httpx.Router) error {
return nil
}
func (s *engine) createMetrics() *stat.Metrics {
func (ng *engine) createMetrics() *stat.Metrics {
var metrics *stat.Metrics
if len(s.conf.Name) > 0 {
metrics = stat.NewMetrics(s.conf.Name)
if len(ng.conf.Name) > 0 {
metrics = stat.NewMetrics(ng.conf.Name)
} else {
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
}
return metrics
}
func (s *engine) getLogHandler() func(http.Handler) http.Handler {
if s.conf.Verbose {
func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
if ng.conf.Verbose {
return handler.DetailedLogHandler
}
return handler.LogHandler
}
func (s *engine) getShedder(priority bool) load.Shedder {
if priority && s.priorityShedder != nil {
return s.priorityShedder
func (ng *engine) getShedder(priority bool) load.Shedder {
if priority && ng.priorityShedder != nil {
return ng.priorityShedder
}
return ng.shedder
}
return s.shedder
func (ng *engine) setTlsConfig(cfg *tls.Config) {
ng.tlsConfig = cfg
}
func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
if !signature.enabled {
return func(chain alice.Chain) alice.Chain {
return chain
@ -201,9 +210,9 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
}
return func(chain alice.Chain) alice.Chain {
if s.unsignedCallback != nil {
if ng.unsignedCallback != nil {
return chain.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
}
return chain.Append(handler.ContentSecurityHandler(
@ -211,8 +220,8 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
}, nil
}
func (s *engine) use(middleware Middleware) {
s.middlewares = append(s.middlewares, middleware)
func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware)
}
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {

@ -2,38 +2,46 @@ package internal
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/proc"
)
// StartOption defines the method to customize http.Server.
type StartOption func(srv *http.Server)
// StartHttp starts a http server.
func StartHttp(host string, port int, handler http.Handler) error {
return start(host, port, handler, nil, func(srv *http.Server) error {
func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error {
return start(host, port, handler, func(srv *http.Server) error {
return srv.ListenAndServe()
})
}, opts...)
}
// StartHttps starts a https server.
func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error {
return start(host, port, handler, tlsConfig, func(srv *http.Server) error {
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
opts ...StartOption) error {
return start(host, port, handler, func(srv *http.Server) error {
// certFile and keyFile are set in buildHttpsServer
return srv.ListenAndServeTLS(certFile, keyFile)
})
}, opts...)
}
func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) {
func start(host string, port int, handler http.Handler, run func(srv *http.Server) error,
opts ...StartOption) (err error) {
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", host, port),
Handler: handler,
}
if tlsConfig != nil {
server.TLSConfig = tlsConfig
for _, opt := range opts {
opt(server)
}
waitForCalled := proc.AddWrapUpListener(func() {
server.Shutdown(context.Background())
if e := server.Shutdown(context.Background()); err != nil {
logx.Error(e)
}
})
defer func() {
if err == http.ErrServerClosed {

@ -48,8 +48,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
server := &Server{
ngin: newEngine(c),
opts: runOptions{
start: func(srv *engine) error {
return srv.Start()
start: func(ng *engine) error {
return ng.Start()
},
},
}
@ -171,8 +171,8 @@ func WithPriority() RouteOption {
// WithRouter returns a RunOption that make server run with given router.
func WithRouter(router httpx.Router) RunOption {
return func(server *Server) {
server.opts.start = func(srv *engine) error {
return srv.StartWithRouter(router)
server.opts.start = func(ng *engine) error {
return ng.StartWithRouter(router)
}
}
}
@ -187,26 +187,24 @@ func WithSignature(signature SignatureConf) RouteOption {
}
}
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(engine *Server) {
engine.ngin.SetUnauthorizedCallback(callback)
// WithTLSConfig returns a RunOption that with given tls config.
func WithTLSConfig(cfg *tls.Config) RunOption {
return func(srv *Server) {
srv.ngin.setTlsConfig(cfg)
}
}
// WithTLSConfig returns a RunOption that with given tls config.
func WithTLSConfig(cipherSuites []uint16) RunOption {
return func(engine *Server) {
engine.ngin.tlsConfig = &tls.Config{
CipherSuites: cipherSuites,
}
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(srv *Server) {
srv.ngin.SetUnauthorizedCallback(callback)
}
}
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(engine *Server) {
engine.ngin.SetUnsignedCallback(callback)
return func(srv *Server) {
srv.ngin.SetUnsignedCallback(callback)
}
}

@ -227,8 +227,10 @@ Port: 54321
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
testConfig := []uint16{
testConfig := &tls.Config{
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
}
testCases := []struct {
@ -239,7 +241,7 @@ Port: 54321
{
c: cnf,
opts: []RunOption{WithTLSConfig(testConfig)},
res: &tls.Config{CipherSuites: testConfig},
res: testConfig,
},
{
c: cnf,

Loading…
Cancel
Save