diff --git a/rest/engine.go b/rest/engine.go index 86e8dead..bdc982bd 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -94,7 +94,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), handler.RecoverHandler, handler.MetricHandler(metrics), - handler.MaxBytesHandler(ng.conf.MaxBytes), + handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), handler.GunzipHandler, ) chain = ng.appendAuthHandler(fr, chain, verifier) @@ -127,6 +127,15 @@ func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { return time.Duration(ng.conf.Timeout) * time.Millisecond } +func (ng *engine) checkedMaxBytes(bytes int64) int64 { + + if bytes > 0 { + return bytes + } + + return ng.conf.MaxBytes +} + func (ng *engine) createMetrics() *stat.Metrics { var metrics *stat.Metrics diff --git a/rest/engine_test.go b/rest/engine_test.go index ca540149..8b8ff528 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -194,6 +194,41 @@ func TestEngine_checkedTimeout(t *testing.T) { } } +func TestEngine_checkedMaxBytes(t *testing.T) { + tests := []struct { + name string + maxBytes int64 + expect int64 + }{ + { + name: "not set", + expect: 1000, + }, + { + name: "less", + maxBytes: 500, + expect: 500, + }, + { + name: "equal", + maxBytes: 1000, + expect: 1000, + }, + { + name: "more", + maxBytes: 1500, + expect: 1500, + }, + } + + ng := newEngine(RestConf{ + MaxBytes: 1000, + }) + for _, test := range tests { + assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes)) + } +} + func TestEngine_notFoundHandler(t *testing.T) { logx.Disable() diff --git a/rest/server.go b/rest/server.go index 793faffa..8839097c 100644 --- a/rest/server.go +++ b/rest/server.go @@ -223,6 +223,13 @@ func WithTimeout(timeout time.Duration) RouteOption { } } +// WithMaxBytes returns a RouteOption to set maxBytes with given value. +func WithMaxBytes(maxBytes int64) RouteOption { + return func(r *featuredRoutes) { + r.maxBytes = maxBytes + } +} + // WithTLSConfig returns a RunOption that with given tls config. func WithTLSConfig(cfg *tls.Config) RunOption { return func(svr *Server) { diff --git a/rest/types.go b/rest/types.go index cc4636c5..f7be7996 100644 --- a/rest/types.go +++ b/rest/types.go @@ -36,5 +36,6 @@ type ( jwt jwtSetting signature signatureSetting routes []Route + maxBytes int64 } ) diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 24161360..2df6c78b 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -7,6 +7,7 @@ import ( "sort" "strings" "text/template" + "time" "github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/tools/goctl/api/spec" @@ -34,7 +35,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { ` routesAdditionTemplate = ` server.AddRoutes( - {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} + {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} ) ` ) @@ -57,6 +58,8 @@ type ( jwtEnabled bool signatureEnabled bool authName string + timeout string + timeoutEnable bool middlewares []string prefix string jwtTrans string @@ -110,6 +113,15 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error rest.WithPrefix("%s"),`, g.prefix) } + var timeout string + if g.timeoutEnable { + duration, err := time.ParseDuration(g.timeout) + if err != nil { + panic(err) + } + timeout = fmt.Sprintf("rest.WithTimeout(%d),", duration) + } + var routes string if len(g.middlewares) > 0 { gbuilder.WriteString("\n}...,") @@ -130,6 +142,7 @@ rest.WithPrefix("%s"),`, g.prefix) "jwt": jwt, "signature": signature, "prefix": prefix, + "timeout": timeout, }); err != nil { return err } @@ -205,6 +218,13 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) { }) } + timeout := g.GetAnnotation("timeout") + + if len(timeout) > 0 { + groupedRoutes.timeoutEnable = true + groupedRoutes.timeout = timeout + } + jwt := g.GetAnnotation("jwt") if len(jwt) > 0 { groupedRoutes.authName = jwt