From 2b9fc26c38b915f9b0f878246d66dea207237fc0 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 31 Mar 2022 21:39:02 +0800 Subject: [PATCH] refactor: guard timeout on API files (#1726) --- core/stores/sqlx/utils.go | 3 ++- rest/engine.go | 17 +++++++-------- rest/server.go | 14 ++++++------ rest/server_test.go | 7 ++++++ tools/goctl/api/gogen/genroutes.go | 34 ++++++++++++++++++------------ 5 files changed, 44 insertions(+), 31 deletions(-) diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go index 6aa966e7..95d2cdaa 100644 --- a/core/stores/sqlx/utils.go +++ b/core/stores/sqlx/utils.go @@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) { break } } + if j > i+1 { index, err := strconv.Atoi(query[i+1 : j]) if err != nil { @@ -85,7 +86,7 @@ func format(query string, args ...interface{}) (string, error) { if index > argIndex { argIndex = index } - + index-- if index < 0 || numArgs <= index { return "", fmt.Errorf("error: wrong index %d in sql", index) diff --git a/rest/engine.go b/rest/engine.go index bdc982bd..0e238d11 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -119,16 +119,7 @@ func (ng *engine) bindRoutes(router httpx.Router) error { return nil } -func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { - if timeout > 0 { - return timeout - } - - return time.Duration(ng.conf.Timeout) * time.Millisecond -} - func (ng *engine) checkedMaxBytes(bytes int64) int64 { - if bytes > 0 { return bytes } @@ -136,6 +127,14 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 { return ng.conf.MaxBytes } +func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { + if timeout > 0 { + return timeout + } + + return time.Duration(ng.conf.Timeout) * time.Millisecond +} + func (ng *engine) createMetrics() *stat.Metrics { var metrics *stat.Metrics diff --git a/rest/server.go b/rest/server.go index 8839097c..669f7084 100644 --- a/rest/server.go +++ b/rest/server.go @@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption { } } +// WithMaxBytes returns a RouteOption to set maxBytes with the given value. +func WithMaxBytes(maxBytes int64) RouteOption { + return func(r *featuredRoutes) { + r.maxBytes = maxBytes + } +} + // WithMiddlewares adds given middlewares to given routes. func WithMiddlewares(ms []Middleware, rs ...Route) []Route { for i := len(ms) - 1; i >= 0; i-- { @@ -223,13 +230,6 @@ func WithTimeout(timeout time.Duration) RouteOption { } } -// WithMaxBytes returns a RouteOption to set maxBytes with given value. -func WithMaxBytes(maxBytes int64) RouteOption { - return func(r *featuredRoutes) { - r.maxBytes = maxBytes - } -} - // WithTLSConfig returns a RunOption that with given tls config. func WithTLSConfig(cfg *tls.Config) RunOption { return func(svr *Server) { diff --git a/rest/server_test.go b/rest/server_test.go index 92537de6..c1c655c9 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -95,6 +95,13 @@ Port: 54321 } } +func TestWithMaxBytes(t *testing.T) { + const maxBytes = 1000 + var fr featuredRoutes + WithMaxBytes(maxBytes)(&fr) + assert.Equal(t, int64(maxBytes), fr.maxBytes) +} + func TestWithMiddleware(t *testing.T) { m := make(map[string]string) rt := router.NewRouter() diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 2df6c78b..c6272dc3 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -24,7 +24,8 @@ const ( package handler import ( - "net/http" + "net/http"{{if .hasTimeout}} + "time"{{end}} {{.importPackages}} ) @@ -38,6 +39,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} ) ` + timeoutThreshold = time.Millisecond ) var mapping = map[string]string{ @@ -59,7 +61,6 @@ type ( signatureEnabled bool authName string timeout string - timeoutEnable bool middlewares []string prefix string jwtTrans string @@ -83,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error return err } + var hasTimeout bool gt := template.Must(template.New("groupTemplate").Parse(templateText)) for _, g := range groups { var gbuilder strings.Builder @@ -114,12 +116,19 @@ rest.WithPrefix("%s"),`, g.prefix) } var timeout string - if g.timeoutEnable { + if len(g.timeout) > 0 { duration, err := time.ParseDuration(g.timeout) if err != nil { - panic(err) + return err } - timeout = fmt.Sprintf("rest.WithTimeout(%d),", duration) + + // why we check this, maybe some users set value 1, it's 1ns, not 1s. + if duration < timeoutThreshold { + return fmt.Errorf("timeout should not less than 1ms, now %v", duration) + } + + timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond) + hasTimeout = true } var routes string @@ -152,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix) if err != nil { return err } - routeFilename = routeFilename + ".go" + routeFilename = routeFilename + ".go" filename := path.Join(dir, handlerDir, routeFilename) os.Remove(filename) @@ -165,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix) category: category, templateFile: routesTemplateFile, builtinTemplate: routesTemplate, - data: map[string]string{ + data: map[string]interface{}{ + "hasTimeout": hasTimeout, "importPackages": genRouteImports(rootPkg, api), "routesAdditions": strings.TrimSpace(builder.String()), }, @@ -184,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string { continue } } - importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder))) + importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), + pathx.JoinPackages(parentPkg, handlerDir, folder))) } } imports := importSet.KeysStr() @@ -218,12 +229,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) { }) } - timeout := g.GetAnnotation("timeout") - - if len(timeout) > 0 { - groupedRoutes.timeoutEnable = true - groupedRoutes.timeout = timeout - } + groupedRoutes.timeout = g.GetAnnotation("timeout") jwt := g.GetAnnotation("jwt") if len(jwt) > 0 {