refactor: simplify tls config in rest (#1181)

master
Kevin Wan 3 years ago committed by GitHub
parent cd1f8da13f
commit 769d06c8ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,4 +54,5 @@ require (
k8s.io/api v0.20.10 k8s.io/api v0.20.10
k8s.io/apimachinery v0.20.10 k8s.io/apimachinery v0.20.10
k8s.io/client-go v0.20.10 k8s.io/client-go v0.20.10
k8s.io/utils v0.0.0-20201110183641-67b214c5f920
) )

@ -35,7 +35,7 @@ type (
KeyFile string `json:",optional"` KeyFile string `json:",optional"`
Verbose bool `json:",optional"` Verbose bool `json:",optional"`
MaxConns int `json:",default=10000"` MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576,range=[0:33554432]"` MaxBytes int64 `json:",default=1048576"`
// milliseconds // milliseconds
Timeout int64 `json:",default=3000"` Timeout int64 `json:",default=3000"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"` CpuThreshold int64 `json:",default=900,range=[0:1000]"`

@ -47,58 +47,63 @@ func newEngine(c RestConf) *engine {
return srv return srv
} }
func (s *engine) AddRoutes(r featuredRoutes) { func (ng *engine) AddRoutes(r featuredRoutes) {
s.routes = append(s.routes, r) ng.routes = append(ng.routes, r)
} }
func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) { func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
s.unauthorizedCallback = callback ng.unauthorizedCallback = callback
} }
func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) { func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
s.unsignedCallback = callback ng.unsignedCallback = callback
} }
func (s *engine) Start() error { func (ng *engine) Start() error {
return s.StartWithRouter(router.NewRouter()) return ng.StartWithRouter(router.NewRouter())
} }
func (s *engine) StartWithRouter(router httpx.Router) error { func (ng *engine) StartWithRouter(router httpx.Router) error {
if err := s.bindRoutes(router); err != nil { if err := ng.bindRoutes(router); err != nil {
return err return err
} }
if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 { if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(s.conf.Host, s.conf.Port, router) return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
} }
return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router) return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
ng.conf.KeyFile, router, func(srv *http.Server) {
if ng.tlsConfig != nil {
srv.TLSConfig = ng.tlsConfig
}
})
} }
func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
verifier func(alice.Chain) alice.Chain) alice.Chain { verifier func(alice.Chain) alice.Chain) alice.Chain {
if fr.jwt.enabled { if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 { if len(fr.jwt.prevSecret) == 0 {
chain = chain.Append(handler.Authorize(fr.jwt.secret, chain = chain.Append(handler.Authorize(fr.jwt.secret,
handler.WithUnauthorizedCallback(s.unauthorizedCallback))) handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} else { } else {
chain = chain.Append(handler.Authorize(fr.jwt.secret, chain = chain.Append(handler.Authorize(fr.jwt.secret,
handler.WithPrevSecret(fr.jwt.prevSecret), handler.WithPrevSecret(fr.jwt.prevSecret),
handler.WithUnauthorizedCallback(s.unauthorizedCallback))) handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} }
} }
return verifier(chain) return verifier(chain)
} }
func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
verifier, err := s.signatureVerifier(fr.signature) verifier, err := ng.signatureVerifier(fr.signature)
if err != nil { if err != nil {
return err return err
} }
for _, route := range fr.routes { for _, route := range fr.routes {
if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil { if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
return err return err
} }
} }
@ -106,24 +111,24 @@ func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metr
return nil return nil
} }
func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error { route Route, verifier func(chain alice.Chain) alice.Chain) error {
chain := alice.New( chain := alice.New(
handler.TracingHandler(s.conf.Name, route.Path), handler.TracingHandler(ng.conf.Name, route.Path),
s.getLogHandler(), ng.getLogHandler(),
handler.PrometheusHandler(route.Path), handler.PrometheusHandler(route.Path),
handler.MaxConns(s.conf.MaxConns), handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics), handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(s.getShedder(fr.priority), metrics), handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), handler.TimeoutHandler(time.Duration(ng.conf.Timeout)*time.Millisecond),
handler.RecoverHandler, handler.RecoverHandler,
handler.MetricHandler(metrics), handler.MetricHandler(metrics),
handler.MaxBytesHandler(s.conf.MaxBytes), handler.MaxBytesHandler(ng.conf.MaxBytes),
handler.GunzipHandler, handler.GunzipHandler,
) )
chain = s.appendAuthHandler(fr, chain, verifier) chain = ng.appendAuthHandler(fr, chain, verifier)
for _, middleware := range s.middlewares { for _, middleware := range ng.middlewares {
chain = chain.Append(convertMiddleware(middleware)) chain = chain.Append(convertMiddleware(middleware))
} }
handle := chain.ThenFunc(route.Handler) handle := chain.ThenFunc(route.Handler)
@ -131,11 +136,11 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
return router.Handle(route.Method, route.Path, handle) return router.Handle(route.Method, route.Path, handle)
} }
func (s *engine) bindRoutes(router httpx.Router) error { func (ng *engine) bindRoutes(router httpx.Router) error {
metrics := s.createMetrics() metrics := ng.createMetrics()
for _, fr := range s.routes { for _, fr := range ng.routes {
if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil { if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
return err return err
} }
} }
@ -143,35 +148,39 @@ func (s *engine) bindRoutes(router httpx.Router) error {
return nil return nil
} }
func (s *engine) createMetrics() *stat.Metrics { func (ng *engine) createMetrics() *stat.Metrics {
var metrics *stat.Metrics var metrics *stat.Metrics
if len(s.conf.Name) > 0 { if len(ng.conf.Name) > 0 {
metrics = stat.NewMetrics(s.conf.Name) metrics = stat.NewMetrics(ng.conf.Name)
} else { } else {
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port)) metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
} }
return metrics return metrics
} }
func (s *engine) getLogHandler() func(http.Handler) http.Handler { func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
if s.conf.Verbose { if ng.conf.Verbose {
return handler.DetailedLogHandler return handler.DetailedLogHandler
} }
return handler.LogHandler return handler.LogHandler
} }
func (s *engine) getShedder(priority bool) load.Shedder { func (ng *engine) getShedder(priority bool) load.Shedder {
if priority && s.priorityShedder != nil { if priority && ng.priorityShedder != nil {
return s.priorityShedder return ng.priorityShedder
} }
return s.shedder return ng.shedder
} }
func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { func (ng *engine) setTlsConfig(cfg *tls.Config) {
ng.tlsConfig = cfg
}
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
if !signature.enabled { if !signature.enabled {
return func(chain alice.Chain) alice.Chain { return func(chain alice.Chain) alice.Chain {
return chain return chain
@ -201,9 +210,9 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
} }
return func(chain alice.Chain) alice.Chain { return func(chain alice.Chain) alice.Chain {
if s.unsignedCallback != nil { if ng.unsignedCallback != nil {
return chain.Append(handler.ContentSecurityHandler( return chain.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback)) decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
} }
return chain.Append(handler.ContentSecurityHandler( return chain.Append(handler.ContentSecurityHandler(
@ -211,8 +220,8 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
}, nil }, nil
} }
func (s *engine) use(middleware Middleware) { func (ng *engine) use(middleware Middleware) {
s.middlewares = append(s.middlewares, middleware) ng.middlewares = append(ng.middlewares, middleware)
} }
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {

@ -2,38 +2,46 @@ package internal
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/proc" "github.com/tal-tech/go-zero/core/proc"
) )
// StartOption defines the method to customize http.Server.
type StartOption func(srv *http.Server)
// StartHttp starts a http server. // StartHttp starts a http server.
func StartHttp(host string, port int, handler http.Handler) error { func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error {
return start(host, port, handler, nil, func(srv *http.Server) error { return start(host, port, handler, func(srv *http.Server) error {
return srv.ListenAndServe() return srv.ListenAndServe()
}) }, opts...)
} }
// StartHttps starts a https server. // StartHttps starts a https server.
func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error { func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
return start(host, port, handler, tlsConfig, func(srv *http.Server) error { opts ...StartOption) error {
return start(host, port, handler, func(srv *http.Server) error {
// certFile and keyFile are set in buildHttpsServer // certFile and keyFile are set in buildHttpsServer
return srv.ListenAndServeTLS(certFile, keyFile) return srv.ListenAndServeTLS(certFile, keyFile)
}) }, opts...)
} }
func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) { func start(host string, port int, handler http.Handler, run func(srv *http.Server) error,
opts ...StartOption) (err error) {
server := &http.Server{ server := &http.Server{
Addr: fmt.Sprintf("%s:%d", host, port), Addr: fmt.Sprintf("%s:%d", host, port),
Handler: handler, Handler: handler,
} }
if tlsConfig != nil { for _, opt := range opts {
server.TLSConfig = tlsConfig opt(server)
} }
waitForCalled := proc.AddWrapUpListener(func() { waitForCalled := proc.AddWrapUpListener(func() {
server.Shutdown(context.Background()) if e := server.Shutdown(context.Background()); err != nil {
logx.Error(e)
}
}) })
defer func() { defer func() {
if err == http.ErrServerClosed { if err == http.ErrServerClosed {

@ -48,8 +48,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
server := &Server{ server := &Server{
ngin: newEngine(c), ngin: newEngine(c),
opts: runOptions{ opts: runOptions{
start: func(srv *engine) error { start: func(ng *engine) error {
return srv.Start() return ng.Start()
}, },
}, },
} }
@ -171,8 +171,8 @@ func WithPriority() RouteOption {
// WithRouter returns a RunOption that make server run with given router. // WithRouter returns a RunOption that make server run with given router.
func WithRouter(router httpx.Router) RunOption { func WithRouter(router httpx.Router) RunOption {
return func(server *Server) { return func(server *Server) {
server.opts.start = func(srv *engine) error { server.opts.start = func(ng *engine) error {
return srv.StartWithRouter(router) return ng.StartWithRouter(router)
} }
} }
} }
@ -187,26 +187,24 @@ func WithSignature(signature SignatureConf) RouteOption {
} }
} }
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. // WithTLSConfig returns a RunOption that with given tls config.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { func WithTLSConfig(cfg *tls.Config) RunOption {
return func(engine *Server) { return func(srv *Server) {
engine.ngin.SetUnauthorizedCallback(callback) srv.ngin.setTlsConfig(cfg)
} }
} }
// WithTLSConfig returns a RunOption that with given tls config. // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithTLSConfig(cipherSuites []uint16) RunOption { func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(engine *Server) { return func(srv *Server) {
engine.ngin.tlsConfig = &tls.Config{ srv.ngin.SetUnauthorizedCallback(callback)
CipherSuites: cipherSuites,
}
} }
} }
// WithUnsignedCallback returns a RunOption that with given unsigned callback set. // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(engine *Server) { return func(srv *Server) {
engine.ngin.SetUnsignedCallback(callback) srv.ngin.SetUnsignedCallback(callback)
} }
} }

@ -227,8 +227,10 @@ Port: 54321
var cnf RestConf var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
testConfig := []uint16{ testConfig := &tls.Config{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
} }
testCases := []struct { testCases := []struct {
@ -239,7 +241,7 @@ Port: 54321
{ {
c: cnf, c: cnf,
opts: []RunOption{WithTLSConfig(testConfig)}, opts: []RunOption{WithTLSConfig(testConfig)},
res: &tls.Config{CipherSuites: testConfig}, res: testConfig,
}, },
{ {
c: cnf, c: cnf,

Loading…
Cancel
Save