You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/rest/internal/security/contentsecurity.go

162 lines
4.2 KiB
Go

4 years ago
package security
4 years ago
import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/tal-tech/go-zero/core/codec"
	"github.com/tal-tech/go-zero/core/iox"
	"github.com/tal-tech/go-zero/core/logx"
	"github.com/tal-tech/go-zero/rest/httpx"
)
// Attribute and header names used when reading content-security metadata.
const (
	// signatureField names the signature attribute of the content-security header.
	signatureField = "signature"
	// timeField names the timestamp attribute inside the decrypted secret.
	timeField = "time"
	// requestUriHeader optionally carries the request URI to sign; when it is
	// present and parses as a URL, its path and query replace r.URL's.
	requestUriHeader = "X-Request-Uri"
)
var (
// ErrInvalidContentType is an error that indicates invalid content type.
4 years ago
ErrInvalidContentType = errors.New("invalid content type")
// ErrInvalidHeader is an error that indicates invalid X-Content-Security header.
ErrInvalidHeader = errors.New("invalid X-Content-Security header")
// ErrInvalidKey is an error that indicates invalid key.
ErrInvalidKey = errors.New("invalid key")
// ErrInvalidPublicKey is an error that indicates invalid public key.
ErrInvalidPublicKey = errors.New("invalid public key")
// ErrInvalidSecret is an error that indicates invalid secret.
ErrInvalidSecret = errors.New("invalid secret")
4 years ago
)
// ContentSecurityHeader holds the values ParseContentSecurity extracts from a
// request's content-security header and its decrypted secret.
type ContentSecurityHeader struct {
	// Key is the base64-decoded key from the decrypted secret; VerifySignature
	// uses it as the HMAC key.
	Key []byte
	// Timestamp is the raw time attribute from the decrypted secret, stored as
	// received; VerifySignature parses it as base-10 Unix seconds.
	Timestamp string
	// ContentType is the integer content type from the decrypted secret.
	ContentType int
	// Signature is the client-supplied signature from the outer header.
	Signature string
}
// Encrypted checks if it's a crypted request.
4 years ago
func (h *ContentSecurityHeader) Encrypted() bool {
return h.ContentType == httpx.CryptionType
}
// ParseContentSecurity parses content security settings in give r.
4 years ago
func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
*ContentSecurityHeader, error) {
contentSecurity := r.Header.Get(httpx.ContentSecurity)
attrs := httpx.ParseHeader(contentSecurity)
fingerprint := attrs[httpx.KeyField]
secret := attrs[httpx.SecretField]
signature := attrs[signatureField]
if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
return nil, ErrInvalidHeader
}
decrypter, ok := decrypters[fingerprint]
if !ok {
return nil, ErrInvalidPublicKey
}
decryptedSecret, err := decrypter.DecryptBase64(secret)
if err != nil {
return nil, ErrInvalidSecret
}
attrs = httpx.ParseHeader(string(decryptedSecret))
base64Key := attrs[httpx.KeyField]
timestamp := attrs[timeField]
contentType := attrs[httpx.TypeField]
key, err := base64.StdEncoding.DecodeString(base64Key)
if err != nil {
return nil, ErrInvalidKey
}
cType, err := strconv.Atoi(contentType)
if err != nil {
return nil, ErrInvalidContentType
}
return &ContentSecurityHeader{
Key: key,
Timestamp: timestamp,
ContentType: cType,
Signature: signature,
}, nil
}
// VerifySignature verifies the signature in given r.
4 years ago
func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
if err != nil {
return httpx.CodeSignatureInvalidHeader
}
now := time.Now().Unix()
toleranceSeconds := int64(tolerance.Seconds())
if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
return httpx.CodeSignatureWrongTime
}
reqPath, reqQuery := getPathQuery(r)
signContent := strings.Join([]string{
securityHeader.Timestamp,
r.Method,
reqPath,
reqQuery,
computeBodySignature(r),
}, "\n")
actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
/*passed := securityHeader.Signature == actualSignature
4 years ago
if !passed {
logx.Infof("signature different, expect: %s, actual: %s",
securityHeader.Signature, actualSignature)
}
if passed {
return httpx.CodeSignaturePass
}*/
if securityHeader.Signature == actualSignature {
return httpx.CodeSignaturePass
4 years ago
}
logx.Infof("signature different, expect: %s, actual: %s",
securityHeader.Signature, actualSignature)
return httpx.CodeSignatureInvalidToken
4 years ago
}
// computeBodySignature returns the lowercase hex SHA-256 digest of r's body.
//
// The body is consumed while hashing and r.Body is then replaced with an
// in-memory copy, so later readers still see the full body. A nil body, which
// net/http allows on client-constructed requests, is hashed as empty input.
func computeBodySignature(r *http.Request) string {
	sha := sha256.New()
	// Don't pass a nil Body to iox.DupReadCloser/io.Copy; hash it as empty instead.
	if r.Body == nil {
		return fmt.Sprintf("%x", sha.Sum(nil))
	}

	var dup io.ReadCloser
	r.Body, dup = iox.DupReadCloser(r.Body)
	// A read error yields a digest of a truncated body, so the signature check
	// will not match. That is why the error is deliberately not propagated
	// from here.
	_, _ = io.Copy(sha, r.Body)
	r.Body = dup

	return fmt.Sprintf("%x", sha.Sum(nil))
}
// getPathQuery returns the path and raw query that are included in the signed
// content.
//
// If the requestUriHeader header is set and parses as a URL, its path and raw
// query are used. Otherwise (header absent, empty, or unparsable) the values
// come from r.URL.
func getPathQuery(r *http.Request) (string, string) {
	if raw := r.Header.Get(requestUriHeader); raw != "" {
		if u, err := url.Parse(raw); err == nil {
			return u.Path, u.RawQuery
		}
	}

	return r.URL.Path, r.URL.RawQuery
}