refactor: guard timeout on API files (#1726)

master
Kevin Wan 3 years ago committed by GitHub
parent 321dc2d410
commit 2b9fc26c38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) {
break break
} }
} }
if j > i+1 { if j > i+1 {
index, err := strconv.Atoi(query[i+1 : j]) index, err := strconv.Atoi(query[i+1 : j])
if err != nil { if err != nil {
@ -85,7 +86,7 @@ func format(query string, args ...interface{}) (string, error) {
if index > argIndex { if index > argIndex {
argIndex = index argIndex = index
} }
index-- index--
if index < 0 || numArgs <= index { if index < 0 || numArgs <= index {
return "", fmt.Errorf("error: wrong index %d in sql", index) return "", fmt.Errorf("error: wrong index %d in sql", index)

@ -119,16 +119,7 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
return nil return nil
} }
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout
}
return time.Duration(ng.conf.Timeout) * time.Millisecond
}
func (ng *engine) checkedMaxBytes(bytes int64) int64 { func (ng *engine) checkedMaxBytes(bytes int64) int64 {
if bytes > 0 { if bytes > 0 {
return bytes return bytes
} }
@ -136,6 +127,14 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
return ng.conf.MaxBytes return ng.conf.MaxBytes
} }
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout
}
return time.Duration(ng.conf.Timeout) * time.Millisecond
}
func (ng *engine) createMetrics() *stat.Metrics { func (ng *engine) createMetrics() *stat.Metrics {
var metrics *stat.Metrics var metrics *stat.Metrics

@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
} }
} }
// WithMaxBytes returns a RouteOption to set maxBytes with the given value.
func WithMaxBytes(maxBytes int64) RouteOption {
return func(r *featuredRoutes) {
r.maxBytes = maxBytes
}
}
// WithMiddlewares adds given middlewares to given routes. // WithMiddlewares adds given middlewares to given routes.
func WithMiddlewares(ms []Middleware, rs ...Route) []Route { func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- { for i := len(ms) - 1; i >= 0; i-- {
@ -223,13 +230,6 @@ func WithTimeout(timeout time.Duration) RouteOption {
} }
} }
// WithMaxBytes returns a RouteOption to set maxBytes with given value.
func WithMaxBytes(maxBytes int64) RouteOption {
return func(r *featuredRoutes) {
r.maxBytes = maxBytes
}
}
// WithTLSConfig returns a RunOption that with given tls config. // WithTLSConfig returns a RunOption that with given tls config.
func WithTLSConfig(cfg *tls.Config) RunOption { func WithTLSConfig(cfg *tls.Config) RunOption {
return func(svr *Server) { return func(svr *Server) {

@ -95,6 +95,13 @@ Port: 54321
} }
} }
func TestWithMaxBytes(t *testing.T) {
const maxBytes = 1000
var fr featuredRoutes
WithMaxBytes(maxBytes)(&fr)
assert.Equal(t, int64(maxBytes), fr.maxBytes)
}
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
rt := router.NewRouter() rt := router.NewRouter()

@ -24,7 +24,8 @@ const (
package handler package handler
import ( import (
"net/http" "net/http"{{if .hasTimeout}}
"time"{{end}}
{{.importPackages}} {{.importPackages}}
) )
@ -38,6 +39,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
) )
` `
timeoutThreshold = time.Millisecond
) )
var mapping = map[string]string{ var mapping = map[string]string{
@ -59,7 +61,6 @@ type (
signatureEnabled bool signatureEnabled bool
authName string authName string
timeout string timeout string
timeoutEnable bool
middlewares []string middlewares []string
prefix string prefix string
jwtTrans string jwtTrans string
@ -83,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
return err return err
} }
var hasTimeout bool
gt := template.Must(template.New("groupTemplate").Parse(templateText)) gt := template.Must(template.New("groupTemplate").Parse(templateText))
for _, g := range groups { for _, g := range groups {
var gbuilder strings.Builder var gbuilder strings.Builder
@ -114,12 +116,19 @@ rest.WithPrefix("%s"),`, g.prefix)
} }
var timeout string var timeout string
if g.timeoutEnable { if len(g.timeout) > 0 {
duration, err := time.ParseDuration(g.timeout) duration, err := time.ParseDuration(g.timeout)
if err != nil { if err != nil {
panic(err) return err
} }
timeout = fmt.Sprintf("rest.WithTimeout(%d),", duration)
// why we check this, maybe some users set value 1, it's 1ns, not 1s.
if duration < timeoutThreshold {
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
}
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
hasTimeout = true
} }
var routes string var routes string
@ -152,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix)
if err != nil { if err != nil {
return err return err
} }
routeFilename = routeFilename + ".go"
routeFilename = routeFilename + ".go"
filename := path.Join(dir, handlerDir, routeFilename) filename := path.Join(dir, handlerDir, routeFilename)
os.Remove(filename) os.Remove(filename)
@ -165,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix)
category: category, category: category,
templateFile: routesTemplateFile, templateFile: routesTemplateFile,
builtinTemplate: routesTemplate, builtinTemplate: routesTemplate,
data: map[string]string{ data: map[string]interface{}{
"hasTimeout": hasTimeout,
"importPackages": genRouteImports(rootPkg, api), "importPackages": genRouteImports(rootPkg, api),
"routesAdditions": strings.TrimSpace(builder.String()), "routesAdditions": strings.TrimSpace(builder.String()),
}, },
@ -184,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
continue continue
} }
} }
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder))) importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
pathx.JoinPackages(parentPkg, handlerDir, folder)))
} }
} }
imports := importSet.KeysStr() imports := importSet.KeysStr()
@ -218,12 +229,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
}) })
} }
timeout := g.GetAnnotation("timeout") groupedRoutes.timeout = g.GetAnnotation("timeout")
if len(timeout) > 0 {
groupedRoutes.timeoutEnable = true
groupedRoutes.timeout = timeout
}
jwt := g.GetAnnotation("jwt") jwt := g.GetAnnotation("jwt")
if len(jwt) > 0 { if len(jwt) > 0 {

Loading…
Cancel
Save