diff --git a/rest/engine.go b/rest/engine.go index 216d3cad..75fcbc9b 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -1,6 +1,7 @@ package rest import ( + "crypto/tls" "errors" "fmt" "net/http" @@ -30,6 +31,7 @@ type engine struct { middlewares []Middleware shedder load.Shedder priorityShedder load.Shedder + tlsConfig *tls.Config } func newEngine(c RestConf) *engine { @@ -70,7 +72,7 @@ func (s *engine) StartWithRouter(router httpx.Router) error { return internal.StartHttp(s.conf.Host, s.conf.Port, router) } - return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, router) + return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router) } func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, diff --git a/rest/internal/starter.go b/rest/internal/starter.go index e378b6d6..5af8b374 100644 --- a/rest/internal/starter.go +++ b/rest/internal/starter.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/tls" "fmt" "net/http" @@ -10,24 +11,27 @@ import ( // StartHttp starts a http server. func StartHttp(host string, port int, handler http.Handler) error { - return start(host, port, handler, func(srv *http.Server) error { + return start(host, port, handler, nil, func(srv *http.Server) error { return srv.ListenAndServe() }) } // StartHttps starts a https server. -func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error { - return start(host, port, handler, func(srv *http.Server) error { +func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error { + return start(host, port, handler, tlsConfig, func(srv *http.Server) error { // certFile and keyFile are set in buildHttpsServer return srv.ListenAndServeTLS(certFile, keyFile) }) } -func start(host string, port int, handler http.Handler, run func(srv *http.Server) error) (err error) { +func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) { server := &http.Server{ Addr: fmt.Sprintf("%s:%d", host, port), Handler: handler, } + if tlsConfig != nil { + server.TLSConfig = tlsConfig + } waitForCalled := proc.AddWrapUpListener(func() { server.Shutdown(context.Background()) }) diff --git a/rest/server.go b/rest/server.go index fa366aab..4a7b3605 100644 --- a/rest/server.go +++ b/rest/server.go @@ -1,6 +1,7 @@ package rest import ( + "crypto/tls" "log" "net/http" @@ -193,6 +194,15 @@ func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { } } +// WithTLSConfig returns a RunOption that with given tls config. +func WithTLSConfig(cipherSuites []uint16) RunOption { + return func(engine *Server) { + engine.ngin.tlsConfig = &tls.Config{ + CipherSuites: cipherSuites, + } + } +} + // WithUnsignedCallback returns a RunOption that with given unsigned callback set. func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { return func(engine *Server) { diff --git a/rest/server_test.go b/rest/server_test.go index 2c79884d..3958e796 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -1,6 +1,7 @@ package rest import ( + "crypto/tls" "fmt" "io" "net/http" @@ -217,3 +218,39 @@ func TestWithPriority(t *testing.T) { WithPriority()(&fr) assert.True(t, fr.priority) } + +func TestWithTLSConfig(t *testing.T) { + const configYaml = ` +Name: foo +Port: 54321 +` + var cnf RestConf + assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) + + testConfig := []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } + + testCases := []struct { + c RestConf + opts []RunOption + res *tls.Config + }{ + { + c: cnf, + opts: []RunOption{WithTLSConfig(testConfig)}, + res: &tls.Config{CipherSuites: testConfig}, + }, + { + c: cnf, + opts: []RunOption{WithUnsignedCallback(nil)}, + res: nil, + }, + } + + for _, testCase := range testCases { + srv, err := NewServer(testCase.c, testCase.opts...) + assert.Nil(t, err) + assert.Equal(t, srv.ngin.tlsConfig, testCase.res) + } +}