fix golint issues in rest (#529)

master
Kevin Wan 4 years ago committed by GitHub
parent d894b88c3e
commit 655ae8034c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,17 +7,20 @@ import (
) )
type ( type (
// A PrivateKeyConf is a private key config.
PrivateKeyConf struct { PrivateKeyConf struct {
Fingerprint string Fingerprint string
KeyFile string KeyFile string
} }
// A SignatureConf is a signature config.
SignatureConf struct { SignatureConf struct {
Strict bool `json:",default=false"` Strict bool `json:",default=false"`
Expiry time.Duration `json:",default=1h"` Expiry time.Duration `json:",default=1h"`
PrivateKeys []PrivateKeyConf PrivateKeys []PrivateKeyConf
} }
// A RestConf is a http service config.
// Why not name it as Conf, because we need to consider usage like: // Why not name it as Conf, because we need to consider usage like:
// type Config struct { // type Config struct {
// zrpc.RpcConf // zrpc.RpcConf

@ -19,6 +19,7 @@ import (
// use 1000m to represent 100% // use 1000m to represent 100%
const topCpuUsage = 1000 const topCpuUsage = 1000
// ErrSignatureConfig is an error that indicates bad config for signature.
var ErrSignatureConfig = errors.New("bad config for Signature") var ErrSignatureConfig = errors.New("bad config for Signature")
type engine struct { type engine struct {
@ -114,7 +115,7 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
handler.RecoverHandler, handler.RecoverHandler,
handler.MetricHandler(metrics), handler.MetricHandler(metrics),
handler.PromethousHandler(route.Path), handler.PrometheusHandler(route.Path),
handler.MaxBytesHandler(s.conf.MaxBytes), handler.MaxBytesHandler(s.conf.MaxBytes),
handler.GunzipHandler, handler.GunzipHandler,
) )

@ -28,15 +28,19 @@ var (
) )
type ( type (
// A AuthorizeOptions is authorize options.
AuthorizeOptions struct { AuthorizeOptions struct {
PrevSecret string PrevSecret string
Callback UnauthorizedCallback Callback UnauthorizedCallback
} }
// UnauthorizedCallback defines the method of unauthorized callback.
UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error) UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error)
// AuthorizeOption defines the method to customize an AuthorizeOptions.
AuthorizeOption func(opts *AuthorizeOptions) AuthorizeOption func(opts *AuthorizeOptions)
) )
// Authorize returns an authorize middleware.
func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler { func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler {
var authOpts AuthorizeOptions var authOpts AuthorizeOptions
for _, opt := range opts { for _, opt := range opts {
@ -78,12 +82,14 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
} }
} }
// WithPrevSecret returns an AuthorizeOption with setting previous secret.
func WithPrevSecret(secret string) AuthorizeOption { func WithPrevSecret(secret string) AuthorizeOption {
return func(opts *AuthorizeOptions) { return func(opts *AuthorizeOptions) {
opts.PrevSecret = secret opts.PrevSecret = secret
} }
} }
// WithUnauthorizedCallback returns an AuthorizeOption with setting unauthorized callback.
func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption { func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
return func(opts *AuthorizeOptions) { return func(opts *AuthorizeOptions) {
opts.Callback = callback opts.Callback = callback

@ -14,6 +14,7 @@ import (
const breakerSeparator = "://" const breakerSeparator = "://"
// BreakerHandler returns a break circuit middleware.
func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler { func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler {
brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator))) brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator)))
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

@ -12,8 +12,10 @@ import (
const contentSecurity = "X-Content-Security" const contentSecurity = "X-Content-Security"
// UnsignedCallback defines the method of the unsigned callback.
type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int)
// ContentSecurityHandler returns a middleware to verify content security.
func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration, func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler { strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
if len(callbacks) == 0 { if len(callbacks) == 0 {

@ -16,6 +16,7 @@ const maxBytes = 1 << 20 // 1 MiB
var errContentLengthExceeded = errors.New("content length exceeded") var errContentLengthExceeded = errors.New("content length exceeded")
// CryptionHandler returns a middleware to handle cryption.
func CryptionHandler(key []byte) func(http.Handler) http.Handler { func CryptionHandler(key []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

@ -10,6 +10,7 @@ import (
const gzipEncoding = "gzip" const gzipEncoding = "gzip"
// GunzipHandler returns a middleware to gunzip http request body.
func GunzipHandler(next http.Handler) http.Handler { func GunzipHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) { if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) {

@ -19,36 +19,37 @@ import (
const slowThreshold = time.Millisecond * 500 const slowThreshold = time.Millisecond * 500
type LoggedResponseWriter struct { type loggedResponseWriter struct {
w http.ResponseWriter w http.ResponseWriter
r *http.Request r *http.Request
code int code int
} }
func (w *LoggedResponseWriter) Header() http.Header { func (w *loggedResponseWriter) Header() http.Header {
return w.w.Header() return w.w.Header()
} }
func (w *LoggedResponseWriter) Write(bytes []byte) (int, error) { func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
return w.w.Write(bytes) return w.w.Write(bytes)
} }
func (w *LoggedResponseWriter) WriteHeader(code int) { func (w *loggedResponseWriter) WriteHeader(code int) {
w.w.WriteHeader(code) w.w.WriteHeader(code)
w.code = code w.code = code
} }
func (w *LoggedResponseWriter) Flush() { func (w *loggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok { if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }
} }
// LogHandler returns a middleware that logs http request and response.
func LogHandler(next http.Handler) http.Handler { func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
logs := new(internal.LogCollector) logs := new(internal.LogCollector)
lrw := LoggedResponseWriter{ lrw := loggedResponseWriter{
w: w, w: w,
r: r, r: r,
code: http.StatusOK, code: http.StatusOK,
@ -62,40 +63,41 @@ func LogHandler(next http.Handler) http.Handler {
}) })
} }
type DetailLoggedResponseWriter struct { type detailLoggedResponseWriter struct {
writer *LoggedResponseWriter writer *loggedResponseWriter
buf *bytes.Buffer buf *bytes.Buffer
} }
func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buffer) *DetailLoggedResponseWriter { func newDetailLoggedResponseWriter(writer *loggedResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
return &DetailLoggedResponseWriter{ return &detailLoggedResponseWriter{
writer: writer, writer: writer,
buf: buf, buf: buf,
} }
} }
func (w *DetailLoggedResponseWriter) Flush() { func (w *detailLoggedResponseWriter) Flush() {
w.writer.Flush() w.writer.Flush()
} }
func (w *DetailLoggedResponseWriter) Header() http.Header { func (w *detailLoggedResponseWriter) Header() http.Header {
return w.writer.Header() return w.writer.Header()
} }
func (w *DetailLoggedResponseWriter) Write(bs []byte) (int, error) { func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
w.buf.Write(bs) w.buf.Write(bs)
return w.writer.Write(bs) return w.writer.Write(bs)
} }
func (w *DetailLoggedResponseWriter) WriteHeader(code int) { func (w *detailLoggedResponseWriter) WriteHeader(code int) {
w.writer.WriteHeader(code) w.writer.WriteHeader(code)
} }
// DetailedLogHandler returns a middleware that logs http request and response in details.
func DetailedLogHandler(next http.Handler) http.Handler { func DetailedLogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
var buf bytes.Buffer var buf bytes.Buffer
lrw := newDetailLoggedResponseWriter(&LoggedResponseWriter{ lrw := newDetailLoggedResponseWriter(&loggedResponseWriter{
w: w, w: w,
r: r, r: r,
code: http.StatusOK, code: http.StatusOK,
@ -146,7 +148,7 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *intern
} }
} }
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer, func logDetails(r *http.Request, response *detailLoggedResponseWriter, timer *utils.ElapsedTimer,
logs *internal.LogCollector) { logs *internal.LogCollector) {
var buf bytes.Buffer var buf bytes.Buffer
duration := timer.Duration() duration := timer.Duration()

@ -6,6 +6,7 @@ import (
"github.com/tal-tech/go-zero/rest/internal" "github.com/tal-tech/go-zero/rest/internal"
) )
// MaxBytesHandler returns a middleware that limit reading of http request body.
func MaxBytesHandler(n int64) func(http.Handler) http.Handler { func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
if n <= 0 { if n <= 0 {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

@ -8,6 +8,7 @@ import (
"github.com/tal-tech/go-zero/rest/internal" "github.com/tal-tech/go-zero/rest/internal"
) )
// MaxConns returns a middleware that limit the concurrent connections.
func MaxConns(n int) func(http.Handler) http.Handler { func MaxConns(n int) func(http.Handler) http.Handler {
if n <= 0 { if n <= 0 {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

@ -7,6 +7,7 @@ import (
"github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
// MetricHandler returns a middleware that stat the metrics.
func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler { func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

@ -31,7 +31,8 @@ var (
}) })
) )
func PromethousHandler(path string) func(http.Handler) http.Handler { // PrometheusHandler returns a middleware that reports stats to prometheus.
func PrometheusHandler(path string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now() startTime := timex.Now()

@ -9,7 +9,7 @@ import (
) )
func TestPromMetricHandler(t *testing.T) { func TestPromMetricHandler(t *testing.T) {
promMetricHandler := PromethousHandler("/user/login") promMetricHandler := PrometheusHandler("/user/login")
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})) }))

@ -8,6 +8,7 @@ import (
"github.com/tal-tech/go-zero/rest/internal" "github.com/tal-tech/go-zero/rest/internal"
) )
// RecoverHandler returns a middleware that recovers if panic happens.
func RecoverHandler(next http.Handler) http.Handler { func RecoverHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() { defer func() {

@ -18,6 +18,7 @@ var (
lock sync.Mutex lock sync.Mutex
) )
// SheddingHandler returns a middleware that does load shedding.
func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler { func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
if shedder == nil { if shedder == nil {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

@ -8,6 +8,7 @@ import (
"github.com/tal-tech/go-zero/core/trace" "github.com/tal-tech/go-zero/core/trace"
) )
// TracingHandler returns a middleware that traces the request.
func TracingHandler(next http.Handler) http.Handler { func TracingHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
carrier, err := trace.Extract(trace.HttpFormat, r.Header) carrier, err := trace.Extract(trace.HttpFormat, r.Header)

@ -60,6 +60,7 @@ func ParseForm(r *http.Request, v interface{}) error {
return formUnmarshaler.Unmarshal(params, v) return formUnmarshaler.Unmarshal(params, v)
} }
// ParseHeader parses the request header and returns a map.
func ParseHeader(headerValue string) map[string]string { func ParseHeader(headerValue string) map[string]string {
ret := make(map[string]string) ret := make(map[string]string)
fields := strings.Split(headerValue, separator) fields := strings.Split(headerValue, separator)

@ -13,6 +13,7 @@ var (
lock sync.RWMutex lock sync.RWMutex
) )
// Error writes err into w.
func Error(w http.ResponseWriter, err error) { func Error(w http.ResponseWriter, err error) {
lock.RLock() lock.RLock()
handler := errorHandler handler := errorHandler
@ -32,20 +33,24 @@ func Error(w http.ResponseWriter, err error) {
} }
} }
// Ok writes HTTP 200 OK into w.
func Ok(w http.ResponseWriter) { func Ok(w http.ResponseWriter) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
// OkJson writes v into w with 200 OK.
func OkJson(w http.ResponseWriter, v interface{}) { func OkJson(w http.ResponseWriter, v interface{}) {
WriteJson(w, http.StatusOK, v) WriteJson(w, http.StatusOK, v)
} }
// SetErrorHandler sets the error handler, which is called on calling Error.
func SetErrorHandler(handler func(error) (int, interface{})) { func SetErrorHandler(handler func(error) (int, interface{})) {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
errorHandler = handler errorHandler = handler
} }
// WriteJson writes v as json string into w with code.
func WriteJson(w http.ResponseWriter, code int, v interface{}) { func WriteJson(w http.ResponseWriter, code int, v interface{}) {
w.Header().Set(ContentType, ApplicationJson) w.Header().Set(ContentType, ApplicationJson)
w.WriteHeader(code) w.WriteHeader(code)

@ -2,6 +2,7 @@ package httpx
import "net/http" import "net/http"
// Router interface represents a http router that handles http requests.
type Router interface { type Router interface {
http.Handler http.Handler
Handle(method string, path string, handler http.Handler) error Handle(method string, path string, handler http.Handler) error

@ -1,19 +1,31 @@
package httpx package httpx
const ( const (
// ApplicationJson means application/json.
ApplicationJson = "application/json" ApplicationJson = "application/json"
// ContentEncoding means Content-Encoding.
ContentEncoding = "Content-Encoding" ContentEncoding = "Content-Encoding"
// ContentSecurity means X-Content-Security.
ContentSecurity = "X-Content-Security" ContentSecurity = "X-Content-Security"
// ContentType means Content-Type.
ContentType = "Content-Type" ContentType = "Content-Type"
// KeyField means key.
KeyField = "key" KeyField = "key"
// SecretField means secret.
SecretField = "secret" SecretField = "secret"
// TypeField means type.
TypeField = "type" TypeField = "type"
// CryptionType means cryption.
CryptionType = 1 CryptionType = 1
) )
const ( const (
// CodeSignaturePass means signature verification passed.
CodeSignaturePass = iota CodeSignaturePass = iota
// CodeSignatureInvalidHeader means invalid header in signature.
CodeSignatureInvalidHeader CodeSignatureInvalidHeader
// CodeSignatureWrongTime means wrong timestamp in signature.
CodeSignatureWrongTime CodeSignatureWrongTime
// CodeSignatureInvalidToken means invalid token in signature.
CodeSignatureInvalidToken CodeSignatureInvalidToken
) )

@ -7,6 +7,7 @@ import (
var pathVars = contextKey("pathVars") var pathVars = contextKey("pathVars")
// Vars parses path variables and returns a map.
func Vars(r *http.Request) map[string]string { func Vars(r *http.Request) map[string]string {
vars, ok := r.Context().Value(pathVars).(map[string]string) vars, ok := r.Context().Value(pathVars).(map[string]string)
if ok { if ok {
@ -16,6 +17,7 @@ func Vars(r *http.Request) map[string]string {
return nil return nil
} }
// WithPathVars writes params into given r and returns a new http.Request.
func WithPathVars(r *http.Request, params map[string]string) *http.Request { func WithPathVars(r *http.Request, params map[string]string) *http.Request {
return r.WithContext(context.WithValue(r.Context(), pathVars, params)) return r.WithContext(context.WithValue(r.Context(), pathVars, params))
} }

@ -10,19 +10,23 @@ import (
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
) )
// LogContext is a context key.
var LogContext = contextKey("request_logs") var LogContext = contextKey("request_logs")
// A LogCollector is used to collect logs.
type LogCollector struct { type LogCollector struct {
Messages []string Messages []string
lock sync.Mutex lock sync.Mutex
} }
// Append appends msg into log context.
func (lc *LogCollector) Append(msg string) { func (lc *LogCollector) Append(msg string) {
lc.lock.Lock() lc.lock.Lock()
lc.Messages = append(lc.Messages, msg) lc.Messages = append(lc.Messages, msg)
lc.lock.Unlock() lc.lock.Unlock()
} }
// Flush flushes collected logs.
func (lc *LogCollector) Flush() string { func (lc *LogCollector) Flush() string {
var buffer bytes.Buffer var buffer bytes.Buffer
@ -48,18 +52,22 @@ func (lc *LogCollector) takeAll() []string {
return messages return messages
} }
// Error logs the given v along with r in error log.
func Error(r *http.Request, v ...interface{}) { func Error(r *http.Request, v ...interface{}) {
logx.ErrorCaller(1, format(r, v...)) logx.ErrorCaller(1, format(r, v...))
} }
// Errorf logs the given v with format along with r in error log.
func Errorf(r *http.Request, format string, v ...interface{}) { func Errorf(r *http.Request, format string, v ...interface{}) {
logx.ErrorCaller(1, formatf(r, format, v...)) logx.ErrorCaller(1, formatf(r, format, v...))
} }
// Info logs the given v along with r in access log.
func Info(r *http.Request, v ...interface{}) { func Info(r *http.Request, v ...interface{}) {
appendLog(r, format(r, v...)) appendLog(r, format(r, v...))
} }
// Infof logs the given v with format along with r in access log.
func Infof(r *http.Request, format string, v ...interface{}) { func Infof(r *http.Request, format string, v ...interface{}) {
appendLog(r, formatf(r, format, v...)) appendLog(r, formatf(r, format, v...))
} }

@ -25,13 +25,19 @@ const (
) )
var ( var (
// ErrInvalidContentType is an error that indicates invalid content type.
ErrInvalidContentType = errors.New("invalid content type") ErrInvalidContentType = errors.New("invalid content type")
// ErrInvalidHeader is an error that indicates invalid X-Content-Security header.
ErrInvalidHeader = errors.New("invalid X-Content-Security header") ErrInvalidHeader = errors.New("invalid X-Content-Security header")
// ErrInvalidKey is an error that indicates invalid key.
ErrInvalidKey = errors.New("invalid key") ErrInvalidKey = errors.New("invalid key")
// ErrInvalidPublicKey is an error that indicates invalid public key.
ErrInvalidPublicKey = errors.New("invalid public key") ErrInvalidPublicKey = errors.New("invalid public key")
// ErrInvalidSecret is an error that indicates invalid secret.
ErrInvalidSecret = errors.New("invalid secret") ErrInvalidSecret = errors.New("invalid secret")
) )
// A ContentSecurityHeader is a content security header.
type ContentSecurityHeader struct { type ContentSecurityHeader struct {
Key []byte Key []byte
Timestamp string Timestamp string
@ -39,10 +45,12 @@ type ContentSecurityHeader struct {
Signature string Signature string
} }
// Encrypted checks if it's a crypted request.
func (h *ContentSecurityHeader) Encrypted() bool { func (h *ContentSecurityHeader) Encrypted() bool {
return h.ContentType == httpx.CryptionType return h.ContentType == httpx.CryptionType
} }
// ParseContentSecurity parses content security settings in give r.
func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) ( func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
*ContentSecurityHeader, error) { *ContentSecurityHeader, error) {
contentSecurity := r.Header.Get(httpx.ContentSecurity) contentSecurity := r.Header.Get(httpx.ContentSecurity)
@ -88,6 +96,7 @@ func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Requ
}, nil }, nil
} }
// VerifySignature verifies the signature in given r.
func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int { func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64) seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
if err != nil { if err != nil {

@ -2,25 +2,30 @@ package security
import "net/http" import "net/http"
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
type WithCodeResponseWriter struct { type WithCodeResponseWriter struct {
Writer http.ResponseWriter Writer http.ResponseWriter
Code int Code int
} }
// Flush flushes the response writer.
func (w *WithCodeResponseWriter) Flush() { func (w *WithCodeResponseWriter) Flush() {
if flusher, ok := w.Writer.(http.Flusher); ok { if flusher, ok := w.Writer.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }
} }
// Header returns the http header.
func (w *WithCodeResponseWriter) Header() http.Header { func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header() return w.Writer.Header()
} }
// Write writes bytes into w.
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) { func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
return w.Writer.Write(bytes) return w.Writer.Write(bytes)
} }
// WriteHeader writes code into w, and not sealing the writer.
func (w *WithCodeResponseWriter) WriteHeader(code int) { func (w *WithCodeResponseWriter) WriteHeader(code int) {
w.Writer.WriteHeader(code) w.Writer.WriteHeader(code)
w.Code = code w.Code = code

@ -8,12 +8,14 @@ import (
"github.com/tal-tech/go-zero/core/proc" "github.com/tal-tech/go-zero/core/proc"
) )
// StartHttp starts a http server.
func StartHttp(host string, port int, handler http.Handler) error { func StartHttp(host string, port int, handler http.Handler) error {
return start(host, port, handler, func(srv *http.Server) error { return start(host, port, handler, func(srv *http.Server) error {
return srv.ListenAndServe() return srv.ListenAndServe()
}) })
} }
// StartHttps starts a https server.
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error { func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
return start(host, port, handler, func(srv *http.Server) error { return start(host, port, handler, func(srv *http.Server) error {
// certFile and keyFile are set in buildHttpsServer // certFile and keyFile are set in buildHttpsServer

@ -17,7 +17,9 @@ const (
) )
var ( var (
// ErrInvalidMethod is an error that indicates not a valid http method.
ErrInvalidMethod = errors.New("not a valid http method") ErrInvalidMethod = errors.New("not a valid http method")
// ErrInvalidPath is an error that indicates path is not start with /.
ErrInvalidPath = errors.New("path must begin with '/'") ErrInvalidPath = errors.New("path must begin with '/'")
) )
@ -27,6 +29,7 @@ type patRouter struct {
notAllowed http.Handler notAllowed http.Handler
} }
// NewRouter returns a httpx.Router.
func NewRouter() httpx.Router { func NewRouter() httpx.Router {
return &patRouter{ return &patRouter{
trees: make(map[string]*search.Tree), trees: make(map[string]*search.Tree),

@ -15,8 +15,10 @@ type (
start func(*engine) error start func(*engine) error
} }
// RunOption defines the method to customize a Server.
RunOption func(*Server) RunOption func(*Server)
// A Server is a http server.
Server struct { Server struct {
ngin *engine ngin *engine
opts runOptions opts runOptions
@ -58,6 +60,7 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
return server, nil return server, nil
} }
// AddRoutes add given routes into the Server.
func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) { func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) {
r := featuredRoutes{ r := featuredRoutes{
routes: rs, routes: rs,
@ -68,28 +71,34 @@ func (e *Server) AddRoutes(rs []Route, opts ...RouteOption) {
e.ngin.AddRoutes(r) e.ngin.AddRoutes(r)
} }
// AddRoute adds given route into the Server.
func (e *Server) AddRoute(r Route, opts ...RouteOption) { func (e *Server) AddRoute(r Route, opts ...RouteOption) {
e.AddRoutes([]Route{r}, opts...) e.AddRoutes([]Route{r}, opts...)
} }
// Start starts the Server.
func (e *Server) Start() { func (e *Server) Start() {
handleError(e.opts.start(e.ngin)) handleError(e.opts.start(e.ngin))
} }
// Stop stops the Server.
func (e *Server) Stop() { func (e *Server) Stop() {
logx.Close() logx.Close()
} }
// Use adds the given middleware in the Server.
func (e *Server) Use(middleware Middleware) { func (e *Server) Use(middleware Middleware) {
e.ngin.use(middleware) e.ngin.use(middleware)
} }
// ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
return func(handle http.HandlerFunc) http.HandlerFunc { return func(handle http.HandlerFunc) http.HandlerFunc {
return handler(handle).ServeHTTP return handler(handle).ServeHTTP
} }
} }
// WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption { func WithJwt(secret string) RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
validateSecret(secret) validateSecret(secret)
@ -98,6 +107,8 @@ func WithJwt(secret string) RouteOption {
} }
} }
// WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition.
// Which means old and new jwt secrets work together for a peroid.
func WithJwtTransition(secret, prevSecret string) RouteOption { func WithJwtTransition(secret, prevSecret string) RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
// why not validate prevSecret, because prevSecret is an already used one, // why not validate prevSecret, because prevSecret is an already used one,
@ -109,6 +120,7 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
} }
} }
// WithMiddlewares adds given middlewares to given routes.
func WithMiddlewares(ms []Middleware, rs ...Route) []Route { func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- { for i := len(ms) - 1; i >= 0; i-- {
rs = WithMiddleware(ms[i], rs...) rs = WithMiddleware(ms[i], rs...)
@ -116,6 +128,7 @@ func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
return rs return rs
} }
// WithMiddleware adds given middleware to given route.
func WithMiddleware(middleware Middleware, rs ...Route) []Route { func WithMiddleware(middleware Middleware, rs ...Route) []Route {
routes := make([]Route, len(rs)) routes := make([]Route, len(rs))
@ -131,24 +144,28 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
return routes return routes
} }
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
func WithNotFoundHandler(handler http.Handler) RunOption { func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter() rt := router.NewRouter()
rt.SetNotFoundHandler(handler) rt.SetNotFoundHandler(handler)
return WithRouter(rt) return WithRouter(rt)
} }
// WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
func WithNotAllowedHandler(handler http.Handler) RunOption { func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter() rt := router.NewRouter()
rt.SetNotAllowedHandler(handler) rt.SetNotAllowedHandler(handler)
return WithRouter(rt) return WithRouter(rt)
} }
// WithPriority returns a RunOption with priority.
func WithPriority() RouteOption { func WithPriority() RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
r.priority = true r.priority = true
} }
} }
// WithRouter returns a RunOption that make server run with given router.
func WithRouter(router httpx.Router) RunOption { func WithRouter(router httpx.Router) RunOption {
return func(server *Server) { return func(server *Server) {
server.opts.start = func(srv *engine) error { server.opts.start = func(srv *engine) error {
@ -157,6 +174,7 @@ func WithRouter(router httpx.Router) RunOption {
} }
} }
// WithSignature returns a RouteOption to enable signature verification.
func WithSignature(signature SignatureConf) RouteOption { func WithSignature(signature SignatureConf) RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
r.signature.enabled = true r.signature.enabled = true
@ -166,12 +184,14 @@ func WithSignature(signature SignatureConf) RouteOption {
} }
} }
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(engine *Server) { return func(engine *Server) {
engine.ngin.SetUnauthorizedCallback(callback) engine.ngin.SetUnauthorizedCallback(callback)
} }
} }
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(engine *Server) { return func(engine *Server) {
engine.ngin.SetUnsignedCallback(callback) engine.ngin.SetUnsignedCallback(callback)

@ -14,8 +14,10 @@ import (
const claimHistoryResetDuration = time.Hour * 24 const claimHistoryResetDuration = time.Hour * 24
type ( type (
// ParseOption defines the method to customize a TokenParser.
ParseOption func(parser *TokenParser) ParseOption func(parser *TokenParser)
// A TokenParser is used to parse tokens.
TokenParser struct { TokenParser struct {
resetTime time.Duration resetTime time.Duration
resetDuration time.Duration resetDuration time.Duration
@ -23,6 +25,7 @@ type (
} }
) )
// NewTokenParser returns a TokenParser.
func NewTokenParser(opts ...ParseOption) *TokenParser { func NewTokenParser(opts ...ParseOption) *TokenParser {
parser := &TokenParser{ parser := &TokenParser{
resetTime: timex.Now(), resetTime: timex.Now(),
@ -36,6 +39,7 @@ func NewTokenParser(opts ...ParseOption) *TokenParser {
return parser return parser
} }
// ParseToken parses token from given r, with passed in secret and prevSecret.
func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) { func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
var token *jwt.Token var token *jwt.Token
var err error var err error
@ -108,6 +112,7 @@ func (tp *TokenParser) loadCount(secret string) uint64 {
return 0 return 0
} }
// WithResetDuration returns a func to customize a TokenParser with reset duration.
func WithResetDuration(duration time.Duration) ParseOption { func WithResetDuration(duration time.Duration) ParseOption {
return func(parser *TokenParser) { return func(parser *TokenParser) {
parser.resetDuration = duration parser.resetDuration = duration

@ -3,14 +3,17 @@ package rest
import "net/http" import "net/http"
type ( type (
// Middleware defines the middleware method.
Middleware func(next http.HandlerFunc) http.HandlerFunc Middleware func(next http.HandlerFunc) http.HandlerFunc
// A Route is a http route.
Route struct { Route struct {
Method string Method string
Path string Path string
Handler http.HandlerFunc Handler http.HandlerFunc
} }
// RouteOption defines the method to customize a featured route.
RouteOption func(r *featuredRoutes) RouteOption func(r *featuredRoutes)
jwtSetting struct { jwtSetting struct {

Loading…
Cancel
Save