chore: refine rest validator (#2928)

* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
master
Kevin Wan 2 years ago committed by GitHub
parent 92c8899f47
commit 66be213346
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/rest/internal/encoding" "github.com/zeromicro/go-zero/rest/internal/encoding"
@ -23,15 +24,13 @@ const (
var ( var (
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
xValidator Validator validator atomic.Value
) )
// Validator defines the interface for validating the request.
type Validator interface { type Validator interface {
Validate(data interface{}, lang string) error // Validate validates the request and parsed data.
} Validate(r *http.Request, data any) error
func SetValidator(validator Validator) {
xValidator = validator
} }
// Parse parses the request. // Parse parses the request.
@ -52,9 +51,10 @@ func Parse(r *http.Request, v any) error {
return err return err
} }
if xValidator != nil { if val := validator.Load(); val != nil {
return xValidator.Validate(v, r.Header.Get("Accept-Language")) return val.(Validator).Validate(r, v)
} }
return nil return nil
} }
@ -117,6 +117,13 @@ func ParsePath(r *http.Request, v any) error {
return pathUnmarshaler.Unmarshal(m, v) return pathUnmarshaler.Unmarshal(m, v)
} }
// SetValidator sets the validator.
// The validator is used to validate the request, only called in Parse,
// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath.
func SetValidator(val Validator) {
validator.Store(val)
}
func withJsonBody(r *http.Request) bool { func withJsonBody(r *http.Request) bool {
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson) return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
} }

@ -1,8 +1,10 @@
package httpx package httpx
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType) r.Header.Set(ContentType, header.JsonContentType)
assert.Nil(t, Parse(r, &v)) if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, "kevin", v.Name) assert.Equal(t, "kevin", v.Name)
assert.Equal(t, 18, v.Age) assert.Equal(t, 18, v.Age)
}
})
t.Run("bad body", func(t *testing.T) {
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
body := `{"name":"kevin", "ag": 18}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
assert.Error(t, Parse(r, &v))
}) })
t.Run("hasn't body", func(t *testing.T) { t.Run("hasn't body", func(t *testing.T) {
@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) {
assert.NotNil(t, Parse(r, &v)) assert.NotNil(t, Parse(r, &v))
} }
func TestParseWithValidator(t *testing.T) {
SetValidator(mockValidator{})
var v struct {
Name string `form:"name"`
Age int `form:"age"`
Percent float64 `form:"percent,optional"`
}
r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody)
assert.Nil(t, err)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, "hello", v.Name)
assert.Equal(t, 18, v.Age)
assert.Equal(t, 3.4, v.Percent)
}
}
func TestParseWithValidatorWithError(t *testing.T) {
SetValidator(mockValidator{})
var v struct {
Name string `form:"name"`
Age int `form:"age"`
Percent float64 `form:"percent,optional"`
}
r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody)
assert.Nil(t, err)
assert.Error(t, Parse(r, &v))
}
func BenchmarkParseRaw(b *testing.B) { func BenchmarkParseRaw(b *testing.B) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody) r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
if err != nil { if err != nil {
@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) {
} }
} }
} }
type mockValidator struct{}
func (m mockValidator) Validate(r *http.Request, data any) error {
if r.URL.Path == "/a" {
val := reflect.ValueOf(data).Elem().FieldByName("Name").String()
if val != "hello" {
return errors.New("name is not hello")
}
}
return nil
}

Loading…
Cancel
Save