From 66be213346e55fabd3663fc86c9e8111246969cf Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 26 Feb 2023 21:58:58 +0800 Subject: [PATCH] chore: refine rest validator (#2928) * chore: refine rest validator * chore: add more tests * chore: reformat code * chore: add comments --- rest/httpx/requests.go | 23 ++++++++----- rest/httpx/requests_test.go | 65 +++++++++++++++++++++++++++++++++++-- 2 files changed, 77 insertions(+), 11 deletions(-) diff --git a/rest/httpx/requests.go b/rest/httpx/requests.go index b923d1a2..cd088bb0 100644 --- a/rest/httpx/requests.go +++ b/rest/httpx/requests.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/rest/internal/encoding" @@ -23,15 +24,13 @@ const ( var ( formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) - xValidator Validator + validator atomic.Value ) +// Validator defines the interface for validating the request. type Validator interface { - Validate(data interface{}, lang string) error -} - -func SetValidator(validator Validator) { - xValidator = validator + // Validate validates the request and parsed data. + Validate(r *http.Request, data any) error } // Parse parses the request. @@ -52,9 +51,10 @@ func Parse(r *http.Request, v any) error { return err } - if xValidator != nil { - return xValidator.Validate(v, r.Header.Get("Accept-Language")) + if val := validator.Load(); val != nil { + return val.(Validator).Validate(r, v) } + return nil } @@ -117,6 +117,13 @@ func ParsePath(r *http.Request, v any) error { return pathUnmarshaler.Unmarshal(m, v) } +// SetValidator sets the validator. +// The validator is used to validate the request, only called in Parse, +// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath. +func SetValidator(val Validator) { + validator.Store(val) +} + func withJsonBody(r *http.Request) bool { return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson) } diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index d6601f33..fa7dd9dc 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -1,8 +1,10 @@ package httpx import ( + "errors" "net/http" "net/http/httptest" + "reflect" "strconv" "strings" "testing" @@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) r.Header.Set(ContentType, header.JsonContentType) - assert.Nil(t, Parse(r, &v)) - assert.Equal(t, "kevin", v.Name) - assert.Equal(t, 18, v.Age) + if assert.NoError(t, Parse(r, &v)) { + assert.Equal(t, "kevin", v.Name) + assert.Equal(t, 18, v.Age) + } + }) + + t.Run("bad body", func(t *testing.T) { + var v struct { + Name string `json:"name"` + Age int `json:"age"` + } + + body := `{"name":"kevin", "ag": 18}` + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + r.Header.Set(ContentType, header.JsonContentType) + + assert.Error(t, Parse(r, &v)) }) t.Run("hasn't body", func(t *testing.T) { @@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) { assert.NotNil(t, Parse(r, &v)) } +func TestParseWithValidator(t *testing.T) { + SetValidator(mockValidator{}) + var v struct { + Name string `form:"name"` + Age int `form:"age"` + Percent float64 `form:"percent,optional"` + } + + r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody) + assert.Nil(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.Equal(t, "hello", v.Name) + assert.Equal(t, 18, v.Age) + assert.Equal(t, 3.4, v.Percent) + } +} + +func TestParseWithValidatorWithError(t *testing.T) { + SetValidator(mockValidator{}) + var v struct { + Name string `form:"name"` + Age int `form:"age"` + Percent float64 `form:"percent,optional"` + } + + r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody) + assert.Nil(t, err) + assert.Error(t, Parse(r, &v)) +} + func BenchmarkParseRaw(b *testing.B) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody) if err != nil { @@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) { } } } + +type mockValidator struct{} + +func (m mockValidator) Validate(r *http.Request, data any) error { + if r.URL.Path == "/a" { + val := reflect.ValueOf(data).Elem().FieldByName("Name").String() + if val != "hello" { + return errors.New("name is not hello") + } + } + + return nil +}