package rest import ( "context" "errors" "net/http" "net/http/httptest" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" ) func TestNewEngine(t *testing.T) { yamls := []string{ `Name: foo Port: 54321 `, `Name: foo Port: 54321 CpuThreshold: 500 `, `Name: foo Port: 54321 CpuThreshold: 500 Verbose: true `, } routes := []featuredRoutes{ { jwt: jwtSetting{}, signature: signatureSetting{}, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{}, signature: signatureSetting{}, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{ enabled: true, }, signature: signatureSetting{}, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{ enabled: true, prevSecret: "thesecret", }, signature: signatureSetting{}, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{ enabled: true, }, signature: signatureSetting{}, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{ enabled: true, }, signature: signatureSetting{ enabled: true, }, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{ enabled: true, }, signature: signatureSetting{ enabled: true, SignatureConf: SignatureConf{ Strict: true, }, }, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, { priority: true, jwt: jwtSetting{ enabled: true, }, signature: signatureSetting{ enabled: true, SignatureConf: SignatureConf{ Strict: true, PrivateKeys: []PrivateKeyConf{ { Fingerprint: "a", KeyFile: "b", }, }, }, }, routes: []Route{{ Method: http.MethodGet, Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, } for _, yaml := range yamls { for _, route := range routes { var cnf RestConf assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf)) ng := newEngine(cnf) ng.addRoutes(route) ng.use(func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) } }) assert.NotNil(t, ng.start(mockedRouter{})) } } } func TestEngine_checkedTimeout(t *testing.T) { tests := []struct { name string timeout time.Duration expect time.Duration }{ { name: "not set", expect: time.Second, }, { name: "less", timeout: time.Millisecond * 500, expect: time.Millisecond * 500, }, { name: "equal", timeout: time.Second, expect: time.Second, }, { name: "more", timeout: time.Millisecond * 1500, expect: time.Millisecond * 1500, }, } ng := newEngine(RestConf{ Timeout: 1000, }) for _, test := range tests { assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout)) } } func TestEngine_checkedMaxBytes(t *testing.T) { tests := []struct { name string maxBytes int64 expect int64 }{ { name: "not set", expect: 1000, }, { name: "less", maxBytes: 500, expect: 500, }, { name: "equal", maxBytes: 1000, expect: 1000, }, { name: "more", maxBytes: 1500, expect: 1500, }, } ng := newEngine(RestConf{ MaxBytes: 1000, }) for _, test := range tests { assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes)) } } func TestEngine_checkedChain(t *testing.T) { var called int32 middleware1 := func() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&called, 1) next.ServeHTTP(w, r) atomic.AddInt32(&called, 1) }) } } middleware2 := func() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&called, 1) next.ServeHTTP(w, r) atomic.AddInt32(&called, 1) }) } } server := MustNewServer(RestConf{}, DisableDefaultMiddlewares()) server.Use(ToMiddleware(middleware1())) server.Use(ToMiddleware(middleware2())) server.router = chainRouter{} server.AddRoutes( []Route{ { Method: http.MethodGet, Path: "/", Handler: func(_ http.ResponseWriter, _ *http.Request) { atomic.AddInt32(&called, 1) }, }, }, ) server.ngin.bindRoutes(chainRouter{}) assert.Equal(t, int32(5), atomic.LoadInt32(&called)) } func TestEngine_notFoundHandler(t *testing.T) { logx.Disable() ng := newEngine(RestConf{}) ts := httptest.NewServer(ng.notFoundHandler(nil)) defer ts.Close() client := ts.Client() err := func(ctx context.Context) error { req, err := http.NewRequest("GET", ts.URL+"/bad", nil) assert.Nil(t, err) res, err := client.Do(req) assert.Nil(t, err) assert.Equal(t, http.StatusNotFound, res.StatusCode) return res.Body.Close() }(context.Background()) assert.Nil(t, err) } func TestEngine_notFoundHandlerNotNil(t *testing.T) { logx.Disable() ng := newEngine(RestConf{}) var called int32 ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&called, 1) }))) defer ts.Close() client := ts.Client() err := func(ctx context.Context) error { req, err := http.NewRequest("GET", ts.URL+"/bad", nil) assert.Nil(t, err) res, err := client.Do(req) assert.Nil(t, err) assert.Equal(t, http.StatusNotFound, res.StatusCode) return res.Body.Close() }(context.Background()) assert.Nil(t, err) assert.Equal(t, int32(1), atomic.LoadInt32(&called)) } func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) { logx.Disable() ng := newEngine(RestConf{}) var called int32 ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&called, 1) w.WriteHeader(http.StatusExpectationFailed) }))) defer ts.Close() client := ts.Client() err := func(ctx context.Context) error { req, err := http.NewRequest("GET", ts.URL+"/bad", nil) assert.Nil(t, err) res, err := client.Do(req) assert.Nil(t, err) assert.Equal(t, http.StatusExpectationFailed, res.StatusCode) return res.Body.Close() }(context.Background()) assert.Nil(t, err) assert.Equal(t, int32(1), atomic.LoadInt32(&called)) } func TestEngine_withTimeout(t *testing.T) { logx.Disable() tests := []struct { name string timeout int64 }{ { name: "not set", }, { name: "set", timeout: 1000, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { ng := newEngine(RestConf{Timeout: test.timeout}) svr := &http.Server{} ng.withTimeout()(svr) assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout) assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout) assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*9/10, svr.WriteTimeout) assert.Equal(t, time.Duration(0), svr.IdleTimeout) }) } } type mockedRouter struct{} func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { } func (m mockedRouter) Handle(_, _ string, _ http.Handler) error { return errors.New("foo") } func (m mockedRouter) SetNotFoundHandler(_ http.Handler) { } func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { } type chainRouter struct{} func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { } func (c chainRouter) Handle(_, _ string, handler http.Handler) error { handler.ServeHTTP(nil, nil) return nil } func (c chainRouter) SetNotFoundHandler(_ http.Handler) { } func (c chainRouter) SetNotAllowedHandler(_ http.Handler) { }