master
kevin 4 years ago
parent 121323b8c3
commit ca3934582a

@ -8,11 +8,11 @@ import (
"time" "time"
"zero/core/executors" "zero/core/executors"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/example/graceful/dns/api/svc" "zero/example/graceful/dns/api/svc"
"zero/example/graceful/dns/api/types" "zero/example/graceful/dns/api/types"
"zero/example/graceful/dns/rpc/graceful" "zero/example/graceful/dns/rpc/graceful"
"zero/ngin/httpx"
) )
func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc { func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc {

@ -8,11 +8,11 @@ import (
"time" "time"
"zero/core/executors" "zero/core/executors"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/example/graceful/etcd/api/svc" "zero/example/graceful/etcd/api/svc"
"zero/example/graceful/etcd/api/types" "zero/example/graceful/etcd/api/types"
"zero/example/graceful/etcd/rpc/graceful" "zero/example/graceful/etcd/rpc/graceful"
"zero/ngin/httpx"
) )
func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc { func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc {

@ -4,10 +4,10 @@ import (
"flag" "flag"
"net/http" "net/http"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/core/service" "zero/core/service"
"zero/ngin" "zero/ngin"
"zero/ngin/httpx"
) )
var ( var (

@ -5,10 +5,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/core/service" "zero/core/service"
"zero/ngin" "zero/ngin"
"zero/ngin/httpx"
) )
var ( var (

@ -6,10 +6,10 @@ import (
"net/http" "net/http"
"time" "time"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/core/service" "zero/core/service"
"zero/ngin" "zero/ngin"
"zero/ngin/httpx"
) )
var ( var (

@ -5,10 +5,10 @@ import (
"io" "io"
"net/http" "net/http"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/core/service" "zero/core/service"
"zero/ngin" "zero/ngin"
"zero/ngin/httpx"
) )
var keyPem = flag.String("prikey", "private.pem", "the private key file") var keyPem = flag.String("prikey", "private.pem", "the private key file")

@ -9,8 +9,8 @@ import (
"time" "time"
"zero/core/conf" "zero/core/conf"
"zero/core/httpx"
"zero/ngin" "zero/ngin"
"zero/ngin/httpx"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/dgrijalva/jwt-go/request" "github.com/dgrijalva/jwt-go/request"

@ -6,11 +6,11 @@ import (
"net/http" "net/http"
"zero/core/conf" "zero/core/conf"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/core/service" "zero/core/service"
"zero/example/tracing/remote/portal" "zero/example/tracing/remote/portal"
"zero/ngin" "zero/ngin"
"zero/ngin/httpx"
"zero/rpcx" "zero/rpcx"
) )

@ -1,12 +1,12 @@
package httphandler package handler
import ( import (
"context" "context"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"zero/core/httpsecurity"
"zero/core/logx" "zero/core/logx"
"zero/ngin/internal"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
) )
@ -37,7 +37,7 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
opt(&authOpts) opt(&authOpts)
} }
parser := httpsecurity.NewTokenParser() parser := internal.NewTokenParser()
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret) token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"fmt" "fmt"
@ -6,10 +6,10 @@ import (
"strings" "strings"
"zero/core/breaker" "zero/core/breaker"
"zero/core/httphandler/internal"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/core/stat" "zero/core/stat"
"zero/ngin/internal"
"zero/ngin/internal/security"
) )
const breakerSeparator = "://" const breakerSeparator = "://"
@ -22,12 +22,12 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
if err != nil { if err != nil {
metrics.AddDrop() metrics.AddDrop()
logx.Errorf("[http] dropped, %s - %s - %s", logx.Errorf("[http] dropped, %s - %s - %s",
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()) r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
return return
} }
cw := &internal.WithCodeResponseWriter{Writer: w} cw := &security.WithCodeResponseWriter{Writer: w}
defer func() { defer func() {
if cw.Code < http.StatusInternalServerError { if cw.Code < http.StatusInternalServerError {
promise.Accept() promise.Accept()

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"fmt" "fmt"

@ -1,13 +1,13 @@
package httphandler package handler
import ( import (
"net/http" "net/http"
"time" "time"
"zero/core/codec" "zero/core/codec"
"zero/core/httphandler/internal"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
"zero/ngin/httpx"
"zero/ngin/internal/security"
) )
const contentSecurity = "X-Content-Security" const contentSecurity = "X-Content-Security"
@ -24,12 +24,12 @@ func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut: case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
header, err := internal.ParseContentSecurity(decrypters, r) header, err := security.ParseContentSecurity(decrypters, r)
if err != nil { if err != nil {
logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s", logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s",
r.Header.Get(contentSecurity), err.Error()) r.Header.Get(contentSecurity), err.Error())
executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks) executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
} else if code := internal.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass { } else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
logx.Infof("Signature verification failed, X-Content-Security: %s", logx.Infof("Signature verification failed, X-Content-Security: %s",
r.Header.Get(contentSecurity)) r.Header.Get(contentSecurity))
executeCallbacks(w, r, next, strict, code, callbacks) executeCallbacks(w, r, next, strict, code, callbacks)

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"bytes" "bytes"
@ -18,7 +18,7 @@ import (
"time" "time"
"zero/core/codec" "zero/core/codec"
"zero/core/httpx" "zero/ngin/httpx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"bytes" "bytes"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"bytes" "bytes"

@ -1,11 +1,11 @@
package httphandler package handler
import ( import (
"compress/gzip" "compress/gzip"
"net/http" "net/http"
"strings" "strings"
"zero/core/httpx" "zero/ngin/httpx"
) )
const gzipEncoding = "gzip" const gzipEncoding = "gzip"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"bytes" "bytes"
@ -10,7 +10,7 @@ import (
"testing" "testing"
"zero/core/codec" "zero/core/codec"
"zero/core/httpx" "zero/ngin/httpx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"bytes" "bytes"
@ -9,12 +9,11 @@ import (
"net/http/httputil" "net/http/httputil"
"time" "time"
"zero/core/httplog"
"zero/core/httpx"
"zero/core/iox" "zero/core/iox"
"zero/core/logx" "zero/core/logx"
"zero/core/timex" "zero/core/timex"
"zero/core/utils" "zero/core/utils"
"zero/ngin/internal"
) )
const slowThreshold = time.Millisecond * 500 const slowThreshold = time.Millisecond * 500
@ -41,7 +40,7 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
func LogHandler(next http.Handler) http.Handler { func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
logs := new(httplog.LogCollector) logs := new(internal.LogCollector)
lrw := LoggedResponseWriter{ lrw := LoggedResponseWriter{
w: w, w: w,
r: r, r: r,
@ -50,7 +49,7 @@ func LogHandler(next http.Handler) http.Handler {
var dup io.ReadCloser var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body) r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs))) next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
r.Body = dup r.Body = dup
logBrief(r, lrw.code, timer, logs) logBrief(r, lrw.code, timer, logs)
}) })
@ -93,8 +92,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
var dup io.ReadCloser var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body) r.Body, dup = iox.DupReadCloser(r.Body)
logs := new(httplog.LogCollector) logs := new(internal.LogCollector)
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs))) next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
r.Body = dup r.Body = dup
logDetails(r, lrw, timer, logs) logDetails(r, lrw, timer, logs)
}) })
@ -109,14 +108,14 @@ func dumpRequest(r *http.Request) string {
} }
} }
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplog.LogCollector) { func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *internal.LogCollector) {
var buf bytes.Buffer var buf bytes.Buffer
duration := timer.Duration() duration := timer.Duration()
buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s", buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s",
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))) code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
if duration > slowThreshold { if duration > slowThreshold {
logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)", logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)",
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)) code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
} }
ok := isOkResponse(code) ok := isOkResponse(code)
@ -137,7 +136,7 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplo
} }
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer, func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer,
logs *httplog.LogCollector) { logs *internal.LogCollector) {
var buf bytes.Buffer var buf bytes.Buffer
duration := timer.Duration() duration := timer.Duration()
buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n", buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n",

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"io/ioutil" "io/ioutil"
@ -8,9 +8,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "zero/ngin/internal"
"zero/core/httplog" "github.com/stretchr/testify/assert"
) )
func init() { func init() {
@ -26,7 +26,7 @@ func TestLogHandler(t *testing.T) {
for _, logHandler := range handlers { for _, logHandler := range handlers {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Context().Value(httplog.LogContext).(*httplog.LogCollector).Append("anything") r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
w.Header().Set("X-Test", "test") w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content")) _, err := w.Write([]byte("content"))

@ -1,9 +1,9 @@
package httphandler package handler
import ( import (
"net/http" "net/http"
"zero/core/httplog" "zero/ngin/internal"
) )
func MaxBytesHandler(n int64) func(http.Handler) http.Handler { func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
@ -16,7 +16,7 @@ func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > n { if r.ContentLength > n {
httplog.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d", internal.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
n, r.ContentLength, http.StatusRequestEntityTooLarge) n, r.ContentLength, http.StatusRequestEntityTooLarge)
w.WriteHeader(http.StatusRequestEntityTooLarge) w.WriteHeader(http.StatusRequestEntityTooLarge)
} else { } else {

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"bytes" "bytes"

@ -1,11 +1,11 @@
package httphandler package handler
import ( import (
"net/http" "net/http"
"zero/core/httplog"
"zero/core/logx" "zero/core/logx"
"zero/core/syncx" "zero/core/syncx"
"zero/ngin/internal"
) )
func MaxConns(n int) func(http.Handler) http.Handler { func MaxConns(n int) func(http.Handler) http.Handler {
@ -28,7 +28,7 @@ func MaxConns(n int) func(http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} else { } else {
httplog.Errorf(r, "Concurrent connections over %d, rejected with code %d", internal.Errorf(r, "Concurrent connections over %d, rejected with code %d",
n, http.StatusServiceUnavailable) n, http.StatusServiceUnavailable)
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
} }

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"io/ioutil" "io/ioutil"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,13 +1,13 @@
package httphandler package handler
import ( import (
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
"zero/core/httphandler/internal"
"zero/core/metric" "zero/core/metric"
"zero/core/timex" "zero/core/timex"
"zero/ngin/internal/security"
) )
const serverNamespace = "http_server" const serverNamespace = "http_server"
@ -35,7 +35,7 @@ func PromMetricHandler(path string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now() startTime := timex.Now()
cw := &internal.WithCodeResponseWriter{Writer: w} cw := &security.WithCodeResponseWriter{Writer: w}
defer func() { defer func() {
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path) metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code)) metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,18 +1,18 @@
package httphandler package handler
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"zero/core/httplog" "zero/ngin/internal"
) )
func RecoverHandler(next http.Handler) http.Handler { func RecoverHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() { defer func() {
if result := recover(); result != nil { if result := recover(); result != nil {
httplog.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack())) internal.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
}() }()

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"io/ioutil" "io/ioutil"

@ -1,14 +1,14 @@
package httphandler package handler
import ( import (
"net/http" "net/http"
"sync" "sync"
"zero/core/httphandler/internal"
"zero/core/httpx"
"zero/core/load" "zero/core/load"
"zero/core/logx" "zero/core/logx"
"zero/core/stat" "zero/core/stat"
"zero/ngin/internal"
"zero/ngin/internal/security"
) )
const serviceType = "api" const serviceType = "api"
@ -35,12 +35,12 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
metrics.AddDrop() metrics.AddDrop()
sheddingStat.IncrementDrop() sheddingStat.IncrementDrop()
logx.Errorf("[http] dropped, %s - %s - %s", logx.Errorf("[http] dropped, %s - %s - %s",
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()) r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
return return
} }
cw := &internal.WithCodeResponseWriter{Writer: w} cw := &security.WithCodeResponseWriter{Writer: w}
defer func() { defer func() {
if cw.Code == http.StatusServiceUnavailable { if cw.Code == http.StatusServiceUnavailable {
promise.Fail() promise.Fail()

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"io/ioutil" "io/ioutil"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"io/ioutil" "io/ioutil"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,4 +1,4 @@
package httphandler package handler
import ( import (
"net/http" "net/http"

@ -1,18 +1,16 @@
package httpx package httpx
import ( import (
"errors"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"zero/core/httprouter"
"zero/core/mapping" "zero/core/mapping"
"zero/ngin/internal/context"
) )
const ( const (
multipartFormData = "multipart/form-data" multipartFormData = "multipart/form-data"
xForwardFor = "X-Forward-For"
formKey = "form" formKey = "form"
pathKey = "path" pathKey = "path"
emptyJson = "{}" emptyJson = "{}"
@ -23,21 +21,10 @@ const (
) )
var ( var (
ErrBodylessRequest = errors.New("not a POST|PUT|PATCH request")
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
) )
// Returns the peer address, supports X-Forward-For
func GetRemoteAddr(r *http.Request) string {
v := r.Header.Get(xForwardFor)
if len(v) > 0 {
return v
}
return r.RemoteAddr
}
func Parse(r *http.Request, v interface{}) error { func Parse(r *http.Request, v interface{}) error {
if err := ParsePath(r, v); err != nil { if err := ParsePath(r, v); err != nil {
return err return err
@ -110,7 +97,7 @@ func ParseJsonBody(r *http.Request, v interface{}) error {
// Parses the symbols reside in url path. // Parses the symbols reside in url path.
// Like http://localhost/bag/:name // Like http://localhost/bag/:name
func ParsePath(r *http.Request, v interface{}) error { func ParsePath(r *http.Request, v interface{}) error {
vars := httprouter.Vars(r) vars := context.Vars(r)
m := make(map[string]interface{}, len(vars)) m := make(map[string]interface{}, len(vars))
for k, v := range vars { for k, v := range vars {
m[k] = v m[k] = v

@ -10,7 +10,7 @@ import (
"strings" "strings"
"testing" "testing"
"zero/core/httprouter" "zero/ngin/internal/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -20,15 +20,6 @@ const (
contentLength = "Content-Length" contentLength = "Content-Length"
) )
func TestGetRemoteAddr(t *testing.T) {
host := "8.8.8.8"
r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader(""))
assert.Nil(t, err)
r.Header.Set(xForwardFor, host)
assert.Equal(t, host, GetRemoteAddr(r))
}
func TestParseForm(t *testing.T) { func TestParseForm(t *testing.T) {
var v struct { var v struct {
Name string `form:"name"` Name string `form:"name"`
@ -135,8 +126,8 @@ func TestParseSlice(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router := httprouter.NewPatRouter() rt := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = rt.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Names []string `form:"names"` Names []string `form:"names"`
}{} }{}
@ -150,7 +141,7 @@ func TestParseSlice(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
router.ServeHTTP(rr, r) rt.ServeHTTP(rr, r)
} }
func TestParseJsonPost(t *testing.T) { func TestParseJsonPost(t *testing.T) {
@ -159,7 +150,7 @@ func TestParseJsonPost(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(
w http.ResponseWriter, r *http.Request) { w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -191,7 +182,7 @@ func TestParseJsonPostWithIntSlice(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(
w http.ResponseWriter, r *http.Request) { w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -219,7 +210,7 @@ func TestParseJsonPostError(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -247,7 +238,7 @@ func TestParseJsonPostInvalidRequest(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/", http.HandlerFunc( err = router.Handle(http.MethodPost, "/", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -269,7 +260,7 @@ func TestParseJsonPostRequired(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -292,7 +283,7 @@ func TestParsePath(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -317,7 +308,7 @@ func TestParsePathRequired(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -338,7 +329,7 @@ func TestParseQuery(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -363,7 +354,7 @@ func TestParseQueryRequired(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil) r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@ -383,7 +374,7 @@ func TestParseOptional(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -424,7 +415,7 @@ func TestParseNestedInRequestEmpty(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -463,7 +454,7 @@ func TestParsePtrInRequest(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -494,7 +485,7 @@ func TestParsePtrInRequestEmpty(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc( err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -511,7 +502,7 @@ func TestParseQueryOptional(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -536,7 +527,7 @@ func TestParse(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -574,7 +565,7 @@ func TestParseWrappedRequest(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -606,7 +597,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -639,7 +630,7 @@ func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -671,7 +662,7 @@ func TestParseWrappedRequestPtr(t *testing.T) {
} }
) )
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@ -694,7 +685,7 @@ func TestParseWithAll(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Name string `path:"name"` Name string `path:"name"`
@ -725,7 +716,7 @@ func TestParseWithAllUtf8(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, applicationJsonWithUtf8) r.Header.Set(ContentType, applicationJsonWithUtf8)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -756,7 +747,7 @@ func TestParseWithMissingForm(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -783,7 +774,7 @@ func TestParseWithMissingAllForms(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -809,7 +800,7 @@ func TestParseWithMissingJson(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai"}`)) bytes.NewBufferString(`{"location": "shanghai"}`))
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -835,7 +826,7 @@ func TestParseWithMissingAllJsons(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -862,7 +853,7 @@ func TestParseWithMissingPath(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -889,7 +880,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -916,7 +907,7 @@ func TestParseGetWithContentLengthHeader(t *testing.T) {
r.Header.Set(ContentType, ApplicationJson) r.Header.Set(ContentType, ApplicationJson)
r.Header.Set(contentLength, "1024") r.Header.Set(contentLength, "1024")
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -943,7 +934,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, applicationJsonWithUtf8) r.Header.Set(ContentType, applicationJsonWithUtf8)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@ -969,7 +960,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(ContentType, applicationJsonWithUtf8) r.Header.Set(ContentType, applicationJsonWithUtf8)
router := httprouter.NewPatRouter() router := router.NewPatRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {

@ -0,0 +1,21 @@
package context
import (
"context"
"net/http"
)
const pathVars = "pathVars"
func Vars(r *http.Request) map[string]string {
vars, ok := r.Context().Value(pathVars).(map[string]string)
if ok {
return vars
}
return nil
}
func WithPathVars(r *http.Request, params map[string]string) *http.Request {
return r.WithContext(context.WithValue(r.Context(), pathVars, params))
}

@ -1,4 +1,4 @@
package httplog package internal
import ( import (
"bytes" "bytes"
@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"sync" "sync"
"zero/core/httpx"
"zero/core/logx" "zero/core/logx"
) )
@ -80,5 +79,5 @@ func formatf(r *http.Request, format string, v ...interface{}) string {
} }
func formatWithReq(r *http.Request, v string) string { func formatWithReq(r *http.Request, v string) string {
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v) return fmt.Sprintf("(%s - %s) %s", r.RequestURI, GetRemoteAddr(r), v)
} }

@ -1,4 +1,4 @@
package httplog package internal
import ( import (
"context" "context"

@ -1,18 +1,17 @@
package httprouter package router
import ( import (
"context"
"net/http" "net/http"
"path" "path"
"strings" "strings"
"zero/core/search" "zero/core/search"
"zero/ngin/internal/context"
) )
const ( const (
allowHeader = "Allow" allowHeader = "Allow"
allowMethodSeparator = ", " allowMethodSeparator = ", "
pathVars = "pathVars"
) )
type PatRouter struct { type PatRouter struct {
@ -50,7 +49,7 @@ func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if tree, ok := pr.trees[r.Method]; ok { if tree, ok := pr.trees[r.Method]; ok {
if result, ok := tree.Search(reqPath); ok { if result, ok := tree.Search(reqPath); ok {
if len(result.Params) > 0 { if len(result.Params) > 0 {
r = r.WithContext(context.WithValue(r.Context(), pathVars, result.Params)) r = context.WithPathVars(r, result.Params)
} }
result.Item.(http.Handler).ServeHTTP(w, r) result.Item.(http.Handler).ServeHTTP(w, r)
return return
@ -98,15 +97,6 @@ func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
} }
} }
func Vars(r *http.Request) map[string]string {
vars, ok := r.Context().Value(pathVars).(map[string]string)
if ok {
return vars
}
return nil
}
func validMethod(method string) bool { func validMethod(method string) bool {
return method == http.MethodDelete || method == http.MethodGet || return method == http.MethodDelete || method == http.MethodGet ||
method == http.MethodHead || method == http.MethodOptions || method == http.MethodHead || method == http.MethodOptions ||

@ -1,10 +1,12 @@
package httprouter package router
import ( import (
"net/http" "net/http"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"zero/ngin/internal/context"
) )
type mockedResponseWriter struct { type mockedResponseWriter struct {
@ -78,12 +80,12 @@ func TestPatRouter(t *testing.T) {
router := NewPatRouter() router := NewPatRouter()
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true routed = true
assert.Equal(t, 1, len(Vars(r))) assert.Equal(t, 1, len(context.Vars(r)))
})) }))
assert.Nil(t, err) assert.Nil(t, err)
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true routed = true
assert.Nil(t, Vars(r)) assert.Nil(t, context.Vars(r))
})) }))
assert.Nil(t, err) assert.Nil(t, err)
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

@ -1,4 +1,4 @@
package httprouter package router
import ( import (
"errors" "errors"

@ -1,4 +1,4 @@
package internal package security
import ( import (
"crypto/sha256" "crypto/sha256"
@ -13,9 +13,9 @@ import (
"time" "time"
"zero/core/codec" "zero/core/codec"
"zero/core/httpx"
"zero/core/iox" "zero/core/iox"
"zero/core/logx" "zero/core/logx"
"zero/ngin/httpx"
) )
const ( const (

@ -1,4 +1,4 @@
package internal package security
import "net/http" import "net/http"

@ -1,4 +1,4 @@
package httpserver package internal
import ( import (
"crypto/tls" "crypto/tls"

@ -1,4 +1,4 @@
package httpserver package internal
import ( import (
"context" "context"

@ -1,4 +1,4 @@
package httpsecurity package internal
import ( import (
"net/http" "net/http"

@ -1,4 +1,4 @@
package httpsecurity package internal
import ( import (
"net/http" "net/http"

@ -0,0 +1,14 @@
package internal
import "net/http"
const xForwardFor = "X-Forward-For"
// Returns the peer address, supports X-Forward-For
func GetRemoteAddr(r *http.Request) string {
v := r.Header.Get(xForwardFor)
if len(v) > 0 {
return v
}
return r.RemoteAddr
}

@ -0,0 +1,19 @@
package internal
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetRemoteAddr(t *testing.T) {
host := "8.8.8.8"
r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader(""))
assert.Nil(t, err)
r.Header.Set(xForwardFor, host)
assert.Equal(t, host, GetRemoteAddr(r))
}

@ -4,9 +4,9 @@ import (
"log" "log"
"net/http" "net/http"
"zero/core/httphandler"
"zero/core/httprouter"
"zero/core/logx" "zero/core/logx"
"zero/ngin/handler"
"zero/ngin/internal/router"
) )
type ( type (
@ -124,7 +124,7 @@ func WithPriority() RouteOption {
} }
} }
func WithRouter(router httprouter.Router) RunOption { func WithRouter(router router.Router) RunOption {
return func(engine *Engine) { return func(engine *Engine) {
engine.opts.start = func(srv *server) error { engine.opts.start = func(srv *server) error {
return srv.StartWithRouter(router) return srv.StartWithRouter(router)
@ -141,13 +141,13 @@ func WithSignature(signature SignatureConf) RouteOption {
} }
} }
func WithUnauthorizedCallback(callback httphandler.UnauthorizedCallback) RunOption { func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(engine *Engine) { return func(engine *Engine) {
engine.srv.SetUnauthorizedCallback(callback) engine.srv.SetUnauthorizedCallback(callback)
} }
} }
func WithUnsignedCallback(callback httphandler.UnsignedCallback) RunOption { func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(engine *Engine) { return func(engine *Engine) {
engine.srv.SetUnsignedCallback(callback) engine.srv.SetUnsignedCallback(callback)
} }

@ -7,15 +7,15 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"zero/core/httprouter" "zero/ngin/httpx"
"zero/core/httpx" router2 "zero/ngin/internal/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := httprouter.NewPatRouter() router := router2.NewPatRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var v struct { var v struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`

@ -7,11 +7,11 @@ import (
"time" "time"
"zero/core/codec" "zero/core/codec"
"zero/core/httphandler"
"zero/core/httprouter"
"zero/core/httpserver"
"zero/core/load" "zero/core/load"
"zero/core/stat" "zero/core/stat"
"zero/ngin/handler"
"zero/ngin/internal"
"zero/ngin/internal/router"
"github.com/justinas/alice" "github.com/justinas/alice"
) )
@ -27,8 +27,8 @@ type (
server struct { server struct {
conf NgConf conf NgConf
routes []featuredRoutes routes []featuredRoutes
unauthorizedCallback httphandler.UnauthorizedCallback unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback httphandler.UnsignedCallback unsignedCallback handler.UnsignedCallback
middlewares []Middleware middlewares []Middleware
shedder load.Shedder shedder load.Shedder
priorityShedder load.Shedder priorityShedder load.Shedder
@ -52,43 +52,43 @@ func (s *server) AddRoutes(r featuredRoutes) {
s.routes = append(s.routes, r) s.routes = append(s.routes, r)
} }
func (s *server) SetUnauthorizedCallback(callback httphandler.UnauthorizedCallback) { func (s *server) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
s.unauthorizedCallback = callback s.unauthorizedCallback = callback
} }
func (s *server) SetUnsignedCallback(callback httphandler.UnsignedCallback) { func (s *server) SetUnsignedCallback(callback handler.UnsignedCallback) {
s.unsignedCallback = callback s.unsignedCallback = callback
} }
func (s *server) Start() error { func (s *server) Start() error {
return s.StartWithRouter(httprouter.NewPatRouter()) return s.StartWithRouter(router.NewPatRouter())
} }
func (s *server) StartWithRouter(router httprouter.Router) error { func (s *server) StartWithRouter(router router.Router) error {
if err := s.bindRoutes(router); err != nil { if err := s.bindRoutes(router); err != nil {
return err return err
} }
return httpserver.StartHttp(s.conf.Host, s.conf.Port, router) return internal.StartHttp(s.conf.Host, s.conf.Port, router)
} }
func (s *server) appendAuthHandler(fr featuredRoutes, chain alice.Chain, func (s *server) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
verifier func(alice.Chain) alice.Chain) alice.Chain { verifier func(alice.Chain) alice.Chain) alice.Chain {
if fr.jwt.enabled { if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 { if len(fr.jwt.prevSecret) == 0 {
chain = chain.Append(httphandler.Authorize(fr.jwt.secret, chain = chain.Append(handler.Authorize(fr.jwt.secret,
httphandler.WithUnauthorizedCallback(s.unauthorizedCallback))) handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
} else { } else {
chain = chain.Append(httphandler.Authorize(fr.jwt.secret, chain = chain.Append(handler.Authorize(fr.jwt.secret,
httphandler.WithPrevSecret(fr.jwt.prevSecret), handler.WithPrevSecret(fr.jwt.prevSecret),
httphandler.WithUnauthorizedCallback(s.unauthorizedCallback))) handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
} }
} }
return verifier(chain) return verifier(chain)
} }
func (s *server) bindFeaturedRoutes(router httprouter.Router, fr featuredRoutes, metrics *stat.Metrics) error { func (s *server) bindFeaturedRoutes(router router.Router, fr featuredRoutes, metrics *stat.Metrics) error {
verifier, err := s.signatureVerifier(fr.signature) verifier, err := s.signatureVerifier(fr.signature)
if err != nil { if err != nil {
return err return err
@ -103,20 +103,20 @@ func (s *server) bindFeaturedRoutes(router httprouter.Router, fr featuredRoutes,
return nil return nil
} }
func (s *server) bindRoute(fr featuredRoutes, router httprouter.Router, metrics *stat.Metrics, func (s *server) bindRoute(fr featuredRoutes, router router.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error { route Route, verifier func(chain alice.Chain) alice.Chain) error {
chain := alice.New( chain := alice.New(
httphandler.TracingHandler, handler.TracingHandler,
s.getLogHandler(), s.getLogHandler(),
httphandler.MaxConns(s.conf.MaxConns), handler.MaxConns(s.conf.MaxConns),
httphandler.BreakerHandler(route.Method, route.Path, metrics), handler.BreakerHandler(route.Method, route.Path, metrics),
httphandler.SheddingHandler(s.getShedder(fr.priority), metrics), handler.SheddingHandler(s.getShedder(fr.priority), metrics),
httphandler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
httphandler.RecoverHandler, handler.RecoverHandler,
httphandler.MetricHandler(metrics), handler.MetricHandler(metrics),
httphandler.PromMetricHandler(route.Path), handler.PromMetricHandler(route.Path),
httphandler.MaxBytesHandler(s.conf.MaxBytes), handler.MaxBytesHandler(s.conf.MaxBytes),
httphandler.GunzipHandler, handler.GunzipHandler,
) )
chain = s.appendAuthHandler(fr, chain, verifier) chain = s.appendAuthHandler(fr, chain, verifier)
@ -128,7 +128,7 @@ func (s *server) bindRoute(fr featuredRoutes, router httprouter.Router, metrics
return router.Handle(route.Method, route.Path, handle) return router.Handle(route.Method, route.Path, handle)
} }
func (s *server) bindRoutes(router httprouter.Router) error { func (s *server) bindRoutes(router router.Router) error {
metrics := s.createMetrics() metrics := s.createMetrics()
for _, fr := range s.routes { for _, fr := range s.routes {
@ -154,9 +154,9 @@ func (s *server) createMetrics() *stat.Metrics {
func (s *server) getLogHandler() func(http.Handler) http.Handler { func (s *server) getLogHandler() func(http.Handler) http.Handler {
if s.conf.Verbose { if s.conf.Verbose {
return httphandler.DetailedLogHandler return handler.DetailedLogHandler
} else { } else {
return httphandler.LogHandler return handler.LogHandler
} }
} }
@ -198,10 +198,10 @@ func (s *server) signatureVerifier(signature signatureSetting) (func(chain alice
return func(chain alice.Chain) alice.Chain { return func(chain alice.Chain) alice.Chain {
if s.unsignedCallback != nil { if s.unsignedCallback != nil {
return chain.Append(httphandler.ContentSecurityHandler( return chain.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback)) decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
} else { } else {
return chain.Append(httphandler.ContentSecurityHandler( return chain.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict)) decrypters, signature.Expiry, signature.Strict))
} }
}, nil }, nil

Loading…
Cancel
Save