refactor: guard timeout on API files (#1726)

master
Kevin Wan 3 years ago committed by GitHub
parent 321dc2d410
commit 2b9fc26c38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) {
break
}
}
if j > i+1 {
index, err := strconv.Atoi(query[i+1 : j])
if err != nil {

@ -119,16 +119,7 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
return nil
}
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout
}
return time.Duration(ng.conf.Timeout) * time.Millisecond
}
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
if bytes > 0 {
return bytes
}
@ -136,6 +127,14 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
return ng.conf.MaxBytes
}
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout
}
return time.Duration(ng.conf.Timeout) * time.Millisecond
}
func (ng *engine) createMetrics() *stat.Metrics {
var metrics *stat.Metrics

@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
}
}
// WithMaxBytes returns a RouteOption to set maxBytes with the given value.
func WithMaxBytes(maxBytes int64) RouteOption {
return func(r *featuredRoutes) {
r.maxBytes = maxBytes
}
}
// WithMiddlewares adds given middlewares to given routes.
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- {
@ -223,13 +230,6 @@ func WithTimeout(timeout time.Duration) RouteOption {
}
}
// WithMaxBytes returns a RouteOption to set maxBytes with given value.
func WithMaxBytes(maxBytes int64) RouteOption {
return func(r *featuredRoutes) {
r.maxBytes = maxBytes
}
}
// WithTLSConfig returns a RunOption that with given tls config.
func WithTLSConfig(cfg *tls.Config) RunOption {
return func(svr *Server) {

@ -95,6 +95,13 @@ Port: 54321
}
}
func TestWithMaxBytes(t *testing.T) {
const maxBytes = 1000
var fr featuredRoutes
WithMaxBytes(maxBytes)(&fr)
assert.Equal(t, int64(maxBytes), fr.maxBytes)
}
func TestWithMiddleware(t *testing.T) {
m := make(map[string]string)
rt := router.NewRouter()

@ -24,7 +24,8 @@ const (
package handler
import (
"net/http"
"net/http"{{if .hasTimeout}}
"time"{{end}}
{{.importPackages}}
)
@ -38,6 +39,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
)
`
timeoutThreshold = time.Millisecond
)
var mapping = map[string]string{
@ -59,7 +61,6 @@ type (
signatureEnabled bool
authName string
timeout string
timeoutEnable bool
middlewares []string
prefix string
jwtTrans string
@ -83,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
return err
}
var hasTimeout bool
gt := template.Must(template.New("groupTemplate").Parse(templateText))
for _, g := range groups {
var gbuilder strings.Builder
@ -114,12 +116,19 @@ rest.WithPrefix("%s"),`, g.prefix)
}
var timeout string
if g.timeoutEnable {
if len(g.timeout) > 0 {
duration, err := time.ParseDuration(g.timeout)
if err != nil {
panic(err)
return err
}
// why we check this, maybe some users set value 1, it's 1ns, not 1s.
if duration < timeoutThreshold {
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
}
timeout = fmt.Sprintf("rest.WithTimeout(%d),", duration)
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
hasTimeout = true
}
var routes string
@ -152,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix)
if err != nil {
return err
}
routeFilename = routeFilename + ".go"
routeFilename = routeFilename + ".go"
filename := path.Join(dir, handlerDir, routeFilename)
os.Remove(filename)
@ -165,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix)
category: category,
templateFile: routesTemplateFile,
builtinTemplate: routesTemplate,
data: map[string]string{
data: map[string]interface{}{
"hasTimeout": hasTimeout,
"importPackages": genRouteImports(rootPkg, api),
"routesAdditions": strings.TrimSpace(builder.String()),
},
@ -184,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
continue
}
}
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder)))
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
pathx.JoinPackages(parentPkg, handlerDir, folder)))
}
}
imports := importSet.KeysStr()
@ -218,12 +229,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
})
}
timeout := g.GetAnnotation("timeout")
if len(timeout) > 0 {
groupedRoutes.timeoutEnable = true
groupedRoutes.timeout = timeout
}
groupedRoutes.timeout = g.GetAnnotation("timeout")
jwt := g.GetAnnotation("jwt")
if len(jwt) > 0 {

Loading…
Cancel
Save