You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/rest/handler/cryptionhandler_test.go

163 lines
4.9 KiB
Go

4 years ago
package handler
4 years ago
import (
"bytes"
"crypto/rand"
4 years ago
"encoding/base64"
"io"
4 years ago
"net/http"
"net/http/httptest"
"testing"
"testing/iotest"
4 years ago
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/codec"
4 years ago
)
// Plaintext payloads used by the handler tests in this file.
const (
	// reqText is the plaintext the tests encrypt and send as a request body.
	reqText = "ping"
	// respText is the plaintext the downstream handlers write as a response.
	respText = "pong"
)

// aesKey is the 32-byte key the tests hand to CryptionHandler, decryptBody
// and codec.EcbEncrypt.
var aesKey = []byte("PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4")
func TestCryptionHandlerGet(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
4 years ago
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte(respText))
w.Header().Set("X-Test", "test")
assert.Nil(t, err)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "test", recorder.Header().Get("X-Test"))
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
func TestCryptionHandlerGet_badKey(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
handler := CryptionHandler(append(aesKey, aesKey...))(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte(respText))
w.Header().Set("X-Test", "test")
assert.Nil(t, err)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
}
4 years ago
func TestCryptionHandlerPost(t *testing.T) {
var buf bytes.Buffer
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
assert.Nil(t, err)
buf.WriteString(base64.StdEncoding.EncodeToString(enc))
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
4 years ago
assert.Nil(t, err)
assert.Equal(t, reqText, string(body))
w.Write([]byte(respText))
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
// TestCryptionHandlerPostBadEncryption posts the raw encrypted bytes without
// base64 encoding them and expects a 400 Bad Request.
func TestCryptionHandlerPostBadEncryption(t *testing.T) {
	cipherText, err := codec.EcbEncrypt(aesKey, []byte(reqText))
	assert.Nil(t, err)

	// The body is the raw ciphertext, written as-is.
	payload := bytes.NewBuffer(cipherText)
	post := httptest.NewRequest(http.MethodPost, "/any", payload)

	// The next handler is nil; the test only expects the 400 below.
	rec := httptest.NewRecorder()
	CryptionHandler(aesKey)(nil).ServeHTTP(rec, post)

	assert.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestCryptionHandlerWriteHeader(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
4 years ago
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
// TestCryptionHandlerFlush checks that the writer handed to the downstream
// handler implements http.Flusher, and that after Write followed by Flush the
// recorder body is the base64 encoding of the encrypted respText.
func TestCryptionHandlerFlush(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte(respText))
		assert.Nil(t, err)

		// assert.True only records the failure and returns false; guard the
		// call so a failed type assertion does not become a nil-interface
		// panic on Flush.
		if flusher, ok := w.(http.Flusher); assert.True(t, ok) {
			flusher.Flush()
		}
	}))

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)

	expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
	assert.Nil(t, err)
	assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
// TestCryptionHandler_Hijack checks that calling Hijack on the cryption
// response writer does not panic, both when it wraps a plain recorder and
// when it wraps a mockedHijackable around that recorder.
func TestCryptionHandler_Hijack(t *testing.T) {
	recorder := httptest.NewRecorder()

	// Wrapping the plain recorder.
	plain := newCryptionResponseWriter(recorder)
	assert.NotPanics(t, func() {
		plain.Hijack()
	})

	// Wrapping the mockedHijackable.
	hijackable := newCryptionResponseWriter(mockedHijackable{recorder})
	assert.NotPanics(t, func() {
		hijackable.Hijack()
	})
}
// TestCryptionHandler_ContentTooLong posts maxBytes+1 random bytes to a real
// test server running CryptionHandler and expects a 400 Bad Request.
func TestCryptionHandler_ContentTooLong(t *testing.T) {
	handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	}))
	svr := httptest.NewServer(handler)
	defer svr.Close()

	// One byte more than maxBytes.
	body := make([]byte, maxBytes+1)
	_, err := rand.Read(body)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
	if !assert.Nil(t, err) {
		return
	}

	resp, err := http.DefaultClient.Do(req)
	// resp is nil when Do fails; stop here instead of dereferencing it.
	if !assert.Nil(t, err) {
		return
	}
	// Close the body so the connection is released before svr.Close runs.
	defer resp.Body.Close()

	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}
// TestCryptionHandler_BadBody gives decryptBody a request whose body reader
// always fails with io.ErrUnexpectedEOF and expects that error to come back.
func TestCryptionHandler_BadBody(t *testing.T) {
	failing := iotest.ErrReader(io.ErrUnexpectedEOF)
	brokenReq, reqErr := http.NewRequest(http.MethodPost, "/foo", failing)
	assert.Nil(t, reqErr)

	assert.ErrorIs(t, decryptBody(maxBytes, aesKey, brokenReq), io.ErrUnexpectedEOF)
}
// TestCryptionHandler_BadKey encrypts and base64-encodes a valid body with
// aesKey, then calls decryptBody with aesKey concatenated with itself and
// expects an error.
func TestCryptionHandler_BadKey(t *testing.T) {
	cipherText, encErr := codec.EcbEncrypt(aesKey, []byte(reqText))
	assert.Nil(t, encErr)

	payload := bytes.NewBufferString(base64.StdEncoding.EncodeToString(cipherText))
	post := httptest.NewRequest(http.MethodPost, "/any", payload)

	doubledKey := append(aesKey, aesKey...)
	assert.Error(t, decryptBody(maxBytes, doubledKey, post))
}