From ce4eb6ed6147209c56290e2e584c965046708a3b Mon Sep 17 00:00:00 2001 From: chen quan Date: Sun, 23 Apr 2023 22:22:03 +0800 Subject: [PATCH] fix: fixed #2945 (#2953) Co-authored-by: Kevin Wan --- core/codec/rsa_test.go | 2 ++ rest/engine.go | 7 ++-- rest/engine_test.go | 47 ++++++++++++++++++++++++++ rest/handler/contentsecurityhandler.go | 8 ++++- rest/handler/cryptionhandler.go | 11 ++++-- 5 files changed, 67 insertions(+), 8 deletions(-) diff --git a/core/codec/rsa_test.go b/core/codec/rsa_test.go index 407952bb..68ce6435 100644 --- a/core/codec/rsa_test.go +++ b/core/codec/rsa_test.go @@ -2,6 +2,7 @@ package codec import ( "encoding/base64" + "os" "testing" "github.com/stretchr/testify/assert" @@ -41,6 +42,7 @@ func TestCryption(t *testing.T) { file, err := fs.TempFilenameWithText(priKey) assert.Nil(t, err) + defer os.Remove(file) dec, err := NewRsaDecrypter(file) assert.Nil(t, err) actual, err := dec.Decrypt(ret) diff --git a/rest/engine.go b/rest/engine.go index 68a946fa..b1bdd06d 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -290,14 +290,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chai decrypters[fingerprint] = decrypter } - return func(chn chain.Chain) chain.Chain { + var unsignedCallbacks []handler.UnsignedCallback if ng.unsignedCallback != nil { - return chn.Append(handler.ContentSecurityHandler( - decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback)) + unsignedCallbacks = append(unsignedCallbacks, ng.unsignedCallback) } - return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict)) + return chn.Append(handler.LimitContentSecurityHandler(ng.conf.MaxBytes, decrypters, signature.Expiry, signature.Strict, unsignedCallbacks)) }, nil } diff --git a/rest/engine_test.go b/rest/engine_test.go index ece1d0d8..e9c7cebd 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -6,16 +6,40 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/core/fs" "github.com/zeromicro/go-zero/core/logx" ) +const ( + priKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQC4TJk3onpqb2RYE3wwt23J9SHLFstHGSkUYFLe+nl1dEKHbD+/ +Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/kIE2ko4lbh/v8Fl14AyVR9ms +fKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32pv7q5UimZ205iKSBmgQIDAQAB +AoGAM5mWqGIAXj5z3MkP01/4CDxuyrrGDVD5FHBno3CDgyQa4Gmpa4B0/ywj671B +aTnwKmSmiiCN2qleuQYASixes2zY5fgTzt+7KNkl9JHsy7i606eH2eCKzsUa/s6u +WD8V3w/hGCQ9zYI18ihwyXlGHIgcRz/eeRh+nWcWVJzGOPUCQQD5nr6It/1yHb1p +C6l4fC4xXF19l4KxJjGu1xv/sOpSx0pOqBDEX3Mh//FU954392rUWDXV1/I65BPt +TLphdsu3AkEAvQJ2Qay/lffFj9FaUrvXuftJZ/Ypn0FpaSiUh3Ak3obBT6UvSZS0 +bcYdCJCNHDtBOsWHnIN1x+BcWAPrdU7PhwJBAIQ0dUlH2S3VXnoCOTGc44I1Hzbj +Rc65IdsuBqA3fQN2lX5vOOIog3vgaFrOArg1jBkG1wx5IMvb/EnUN2pjVqUCQCza +KLXtCInOAlPemlCHwumfeAvznmzsWNdbieOZ+SXVVIpR6KbNYwOpv7oIk3Pfm9sW +hNffWlPUKhW42Gc+DIECQQDmk20YgBXwXWRM5DRPbhisIV088N5Z58K9DtFWkZsd +OBDT3dFcgZONtlmR1MqZO0pTh30lA4qovYj3Bx7A8i36 +-----END RSA PRIVATE KEY-----` +) + func TestNewEngine(t *testing.T) { + priKeyfile, err := fs.TempFilenameWithText(priKey) + assert.Nil(t, err) + defer os.Remove(priKeyfile) + yamls := []string{ `Name: foo Host: localhost @@ -151,6 +175,29 @@ Verbose: true Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, }, + { + priority: true, + jwt: jwtSetting{ + enabled: true, + }, + signature: signatureSetting{ + enabled: true, + SignatureConf: SignatureConf{ + Strict: true, + PrivateKeys: []PrivateKeyConf{ + { + Fingerprint: "a", + KeyFile: priKeyfile, + }, + }, + }, + }, + routes: []Route{{ + Method: http.MethodGet, + Path: "/", + Handler: func(w http.ResponseWriter, r *http.Request) {}, + }}, + }, } for _, yaml := range yamls { diff --git a/rest/handler/contentsecurityhandler.go b/rest/handler/contentsecurityhandler.go index 1f8dc39a..ca89dc68 100644 --- a/rest/handler/contentsecurityhandler.go +++ b/rest/handler/contentsecurityhandler.go @@ -18,6 +18,12 @@ type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Han // ContentSecurityHandler returns a middleware to verify content security. func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration, strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler { + return LimitContentSecurityHandler(maxBytes, decrypters, tolerance, strict, callbacks) +} + +// LimitContentSecurityHandler returns a middleware to verify content security. +func LimitContentSecurityHandler(maxBytesSize int64, decrypters map[string]codec.RsaDecrypter, tolerance time.Duration, + strict bool, callbacks []UnsignedCallback) func(http.Handler) http.Handler { if len(callbacks) == 0 { callbacks = append(callbacks, handleVerificationFailure) } @@ -36,7 +42,7 @@ func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance r.Header.Get(contentSecurity)) executeCallbacks(w, r, next, strict, code, callbacks) } else if r.ContentLength > 0 && header.Encrypted() { - CryptionHandler(header.Key)(next).ServeHTTP(w, r) + LimitCryptionHandler(maxBytesSize, header.Key)(next).ServeHTTP(w, r) } else { next.ServeHTTP(w, r) } diff --git a/rest/handler/cryptionhandler.go b/rest/handler/cryptionhandler.go index 9b91c840..df54ee46 100644 --- a/rest/handler/cryptionhandler.go +++ b/rest/handler/cryptionhandler.go @@ -19,6 +19,11 @@ var errContentLengthExceeded = errors.New("content length exceeded") // CryptionHandler returns a middleware to handle cryption. func CryptionHandler(key []byte) func(http.Handler) http.Handler { + return LimitCryptionHandler(maxBytes, key) +} + +// LimitCryptionHandler returns a middleware to handle cryption. +func LimitCryptionHandler(maxBytesSize int64, key []byte) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cw := newCryptionResponseWriter(w) @@ -29,7 +34,7 @@ func CryptionHandler(key []byte) func(http.Handler) http.Handler { return } - if err := decryptBody(key, r); err != nil { + if err := decryptBody(maxBytesSize, key, r); err != nil { w.WriteHeader(http.StatusBadRequest) return } @@ -39,8 +44,8 @@ func CryptionHandler(key []byte) func(http.Handler) http.Handler { } } -func decryptBody(key []byte, r *http.Request) error { - if r.ContentLength > maxBytes { +func decryptBody(maxBytesSize int64, key []byte, r *http.Request) error { + if maxBytesSize > 0 && r.ContentLength > maxBytesSize { return errContentLengthExceeded }