diff --git a/.codecov.yml b/.codecov.yml index b0358f36..4a78255b 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -6,3 +6,4 @@ ignore: - "tools" - "**/mock" - "**/*_mock.go" + - "**/*test" diff --git a/core/logc/logs_test.go b/core/logc/logs_test.go index fd1825cd..3ee98642 100644 --- a/core/logc/logs_test.go +++ b/core/logc/logs_test.go @@ -1,7 +1,6 @@ package logc import ( - "bytes" "context" "encoding/json" "fmt" @@ -11,14 +10,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) func TestAddGlobalFields(t *testing.T) { - var buf bytes.Buffer - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) + buf := logtest.NewCollector(t) Info(context.Background(), "hello") buf.Reset() @@ -34,155 +30,90 @@ func TestAddGlobalFields(t *testing.T) { } func TestAlert(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) Alert(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), "foo"), buf.String()) } func TestError(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Error(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestErrorf(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Errorf(context.Background(), "foo %s", "bar") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestErrorv(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Errorv(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestErrorw(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Errorw(context.Background(), "foo", Field("a", "b")) assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestInfo(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Info(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestInfof(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Infof(context.Background(), "foo %s", "bar") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestInfov(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Infov(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestInfow(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Infow(context.Background(), "foo", Field("a", "b")) assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestDebug(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Debug(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestDebugf(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Debugf(context.Background(), "foo %s", "bar") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestDebugv(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Debugv(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) } func TestDebugw(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Debugw(context.Background(), "foo", Field("a", "b")) assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1))) @@ -204,48 +135,28 @@ func TestMisc(t *testing.T) { } func TestSlow(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Slow(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String()) } func TestSlowf(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Slowf(context.Background(), "foo %s", "bar") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String()) } func TestSlowv(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Slowv(context.Background(), "foo") assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String()) } func TestSloww(t *testing.T) { - var buf strings.Builder - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) - + buf := logtest.NewCollector(t) file, line := getFileLine() Sloww(context.Background(), "foo", Field("a", "b")) assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String()) diff --git a/core/logx/logtest/logtest.go b/core/logx/logtest/logtest.go new file mode 100644 index 00000000..53c2db43 --- /dev/null +++ b/core/logx/logtest/logtest.go @@ -0,0 +1,78 @@ +package logtest + +import ( + "bytes" + "encoding/json" + "io" + "testing" + + "github.com/zeromicro/go-zero/core/logx" +) + +type Buffer struct { + buf *bytes.Buffer + t *testing.T +} + +func Discard(t *testing.T) { + prev := logx.Reset() + logx.SetWriter(logx.NewWriter(io.Discard)) + + t.Cleanup(func() { + logx.SetWriter(prev) + }) +} + +func NewCollector(t *testing.T) *Buffer { + var buf bytes.Buffer + writer := logx.NewWriter(&buf) + prev := logx.Reset() + logx.SetWriter(writer) + + t.Cleanup(func() { + logx.SetWriter(prev) + }) + + return &Buffer{ + buf: &buf, + t: t, + } +} + +func (b *Buffer) Bytes() []byte { + return b.buf.Bytes() +} + +func (b *Buffer) Content() string { + var m map[string]interface{} + if err := json.Unmarshal(b.buf.Bytes(), &m); err != nil { + b.t.Error(err) + return "" + } + + content, ok := m["content"] + if !ok { + return "" + } + + switch val := content.(type) { + case string: + return val + default: + bs, err := json.Marshal(content) + if err != nil { + b.t.Error(err) + return "" + } + + return string(bs) + } +} + +func (b *Buffer) Reset() { + b.buf.Reset() +} + +func (b *Buffer) String() string { + return b.buf.String() +} diff --git a/core/logx/logtest/logtest_test.go b/core/logx/logtest/logtest_test.go new file mode 100644 index 00000000..1a61e07a --- /dev/null +++ b/core/logx/logtest/logtest_test.go @@ -0,0 +1,22 @@ +package logtest + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/logx" +) + +func TestCollector(t *testing.T) { + const input = "hello" + c := NewCollector(t) + logx.Info(input) + assert.Equal(t, input, c.Content()) + assert.Contains(t, c.String(), input) +} + +func TestDiscard(t *testing.T) { + const input = "hello" + Discard(t) + logx.Info(input) +} diff --git a/core/proc/goroutines_test.go b/core/proc/goroutines_test.go index 9464fba9..267264c3 100644 --- a/core/proc/goroutines_test.go +++ b/core/proc/goroutines_test.go @@ -5,19 +5,11 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) func TestDumpGoroutines(t *testing.T) { - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - defer func() { - logx.Reset() - logx.SetWriter(o) - }() - + buf := logtest.NewCollector(t) dumpGoroutines() assert.True(t, strings.Contains(buf.String(), ".dump")) } diff --git a/core/proc/profile_test.go b/core/proc/profile_test.go index eb89ceee..f82caac4 100644 --- a/core/proc/profile_test.go +++ b/core/proc/profile_test.go @@ -5,25 +5,16 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) func TestProfile(t *testing.T) { - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - - defer func() { - logx.Reset() - logx.SetWriter(o) - }() - + c := logtest.NewCollector(t) profiler := StartProfile() // start again should not work assert.NotNil(t, StartProfile()) profiler.Stop() // stop twice profiler.Stop() - assert.True(t, strings.Contains(buf.String(), ".pprof")) + assert.True(t, strings.Contains(c.String(), ".pprof")) } diff --git a/core/stat/usage_test.go b/core/stat/usage_test.go index 6cb5e19b..9fba767b 100644 --- a/core/stat/usage_test.go +++ b/core/stat/usage_test.go @@ -1,12 +1,11 @@ package stat import ( - "bytes" "strings" "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) func TestBToMb(t *testing.T) { @@ -41,15 +40,11 @@ func TestBToMb(t *testing.T) { } func TestPrintUsage(t *testing.T) { - var buf bytes.Buffer - writer := logx.NewWriter(&buf) - old := logx.Reset() - logx.SetWriter(writer) - defer logx.SetWriter(old) + c := logtest.NewCollector(t) printUsage() - output := buf.String() + output := c.String() assert.Contains(t, output, "CPU:") assert.Contains(t, output, "MEMORY:") assert.Contains(t, output, "Alloc=") diff --git a/core/stores/mon/collection_test.go b/core/stores/mon/collection_test.go index d3c14a49..964ebb23 100644 --- a/core/stores/mon/collection_test.go +++ b/core/stores/mon/collection_test.go @@ -3,12 +3,11 @@ package mon import ( "context" "errors" - "strings" "testing" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/breaker" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" "github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/timex" "go.mongodb.org/mongo-driver/bson" @@ -573,15 +572,7 @@ func TestDecoratedCollection_LogDuration(t *testing.T) { brk: breaker.NewBreaker(), } - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - - defer func() { - logx.Reset() - logx.SetWriter(o) - }() + buf := logtest.NewCollector(t) buf.Reset() c.logDuration(context.Background(), "foo", timex.Now(), nil, "bar") diff --git a/core/stores/mon/util_test.go b/core/stores/mon/util_test.go index b042f854..3ab7cf2b 100644 --- a/core/stores/mon/util_test.go +++ b/core/stores/mon/util_test.go @@ -3,12 +3,11 @@ package mon import ( "context" "errors" - "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) func TestFormatAddrs(t *testing.T) { @@ -40,15 +39,7 @@ func TestFormatAddrs(t *testing.T) { } func Test_logDuration(t *testing.T) { - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - - defer func() { - logx.Reset() - logx.SetWriter(o) - }() + buf := logtest.NewCollector(t) buf.Reset() logDuration(context.Background(), "foo", "bar", time.Millisecond, nil) diff --git a/core/stores/redis/hook_test.go b/core/stores/redis/hook_test.go index b055b8f8..1e81a211 100644 --- a/core/stores/redis/hook_test.go +++ b/core/stores/redis/hook_test.go @@ -9,7 +9,7 @@ import ( red "github.com/go-redis/redis/v8" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ztrace "github.com/zeromicro/go-zero/core/trace" tracesdk "go.opentelemetry.io/otel/trace" ) @@ -47,8 +47,7 @@ func TestHookProcessCase2(t *testing.T) { }) defer ztrace.StopAgent() - w, restore := injectLog() - defer restore() + w := logtest.NewCollector(t) ctx, err := durationHook.BeforeProcess(context.Background(), red.NewCmd(context.Background())) if err != nil { @@ -115,8 +114,7 @@ func TestHookProcessPipelineCase2(t *testing.T) { }) defer ztrace.StopAgent() - w, restore := injectLog() - defer restore() + w := logtest.NewCollector(t) ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{ red.NewCmd(context.Background()), @@ -135,8 +133,7 @@ func TestHookProcessPipelineCase2(t *testing.T) { } func TestHookProcessPipelineCase3(t *testing.T) { - w, restore := injectLog() - defer restore() + w := logtest.NewCollector(t) assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{ red.NewCmd(context.Background()), @@ -145,8 +142,7 @@ func TestHookProcessPipelineCase3(t *testing.T) { } func TestHookProcessPipelineCase4(t *testing.T) { - w, restore := injectLog() - defer restore() + w := logtest.NewCollector(t) ctx := context.WithValue(context.Background(), startTimeKey, "foo") assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{ @@ -169,8 +165,7 @@ func TestHookProcessPipelineCase5(t *testing.T) { } func TestLogDuration(t *testing.T) { - w, restore := injectLog() - defer restore() + w := logtest.NewCollector(t) logDuration(context.Background(), []red.Cmder{ red.NewCmd(context.Background(), "get", "foo"), @@ -183,15 +178,3 @@ func TestLogDuration(t *testing.T) { }, 1*time.Second) assert.True(t, strings.Contains(w.String(), `get foo\nset bar 0`)) } - -func injectLog() (r *strings.Builder, restore func()) { - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - - return &buf, func() { - logx.Reset() - logx.SetWriter(o) - } -} diff --git a/core/trace/tracetest/tracetest.go b/core/trace/tracetest/tracetest.go index e1c496bb..fb8b17d9 100644 --- a/core/trace/tracetest/tracetest.go +++ b/core/trace/tracetest/tracetest.go @@ -16,5 +16,6 @@ func NewInMemoryExporter(t *testing.T) *tracetest.InMemoryExporter { me.Reset() }) otel.SetTracerProvider(trace.NewTracerProvider(trace.WithSyncer(me))) + return me } diff --git a/rest/handler/breakerhandler_test.go b/rest/handler/breakerhandler_test.go index bda8fef1..d4f6e4c8 100644 --- a/rest/handler/breakerhandler_test.go +++ b/rest/handler/breakerhandler_test.go @@ -7,12 +7,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stat" ) func init() { - logx.Disable() stat.SetReporter(nil) } diff --git a/rest/handler/contentsecurityhandler_test.go b/rest/handler/contentsecurityhandler_test.go index 80b755bd..ada8b194 100644 --- a/rest/handler/contentsecurityhandler_test.go +++ b/rest/handler/contentsecurityhandler_test.go @@ -62,10 +62,6 @@ type requestSettings struct { signature string } -func init() { - log.SetOutput(io.Discard) -} - func TestContentSecurityHandler(t *testing.T) { tests := []struct { method string diff --git a/rest/handler/cryptionhandler_test.go b/rest/handler/cryptionhandler_test.go index d3fcfe09..b24a5850 100644 --- a/rest/handler/cryptionhandler_test.go +++ b/rest/handler/cryptionhandler_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/base64" "io" - "log" "math/rand" "net/http" "net/http/httptest" @@ -21,10 +20,6 @@ const ( var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`) -func init() { - log.SetOutput(io.Discard) -} - func TestCryptionHandlerGet(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody) handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/rest/handler/loghandler_test.go b/rest/handler/loghandler_test.go index 65cd0b88..2a562576 100644 --- a/rest/handler/loghandler_test.go +++ b/rest/handler/loghandler_test.go @@ -3,7 +3,6 @@ package handler import ( "bytes" "io" - "log" "net/http" "net/http/httptest" "testing" @@ -14,10 +13,6 @@ import ( "github.com/zeromicro/go-zero/rest/internal/response" ) -func init() { - log.SetOutput(io.Discard) -} - func TestLogHandler(t *testing.T) { handlers := []func(handler http.Handler) http.Handler{ LogHandler, diff --git a/rest/handler/maxconnshandler_test.go b/rest/handler/maxconnshandler_test.go index 2e483436..0e64c0c6 100644 --- a/rest/handler/maxconnshandler_test.go +++ b/rest/handler/maxconnshandler_test.go @@ -1,8 +1,6 @@ package handler import ( - "io" - "log" "net/http" "net/http/httptest" "sync" @@ -14,10 +12,6 @@ import ( const conns = 4 -func init() { - log.SetOutput(io.Discard) -} - func TestMaxConnsHandler(t *testing.T) { var waitGroup sync.WaitGroup waitGroup.Add(conns) diff --git a/rest/handler/recoverhandler_test.go b/rest/handler/recoverhandler_test.go index 51189b5a..ef016532 100644 --- a/rest/handler/recoverhandler_test.go +++ b/rest/handler/recoverhandler_test.go @@ -1,8 +1,6 @@ package handler import ( - "io" - "log" "net/http" "net/http/httptest" "testing" @@ -10,10 +8,6 @@ import ( "github.com/stretchr/testify/assert" ) -func init() { - log.SetOutput(io.Discard) -} - func TestWithPanic(t *testing.T) { handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic("whatever") diff --git a/rest/handler/sheddinghandler_test.go b/rest/handler/sheddinghandler_test.go index f3dbd209..1b05d35d 100644 --- a/rest/handler/sheddinghandler_test.go +++ b/rest/handler/sheddinghandler_test.go @@ -1,8 +1,6 @@ package handler import ( - "io" - "log" "net/http" "net/http/httptest" "testing" @@ -12,10 +10,6 @@ import ( "github.com/zeromicro/go-zero/core/stat" ) -func init() { - log.SetOutput(io.Discard) -} - func TestSheddingHandlerAccept(t *testing.T) { metrics := stat.NewMetrics("unit-test") shedder := mockShedder{ diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index 40340a2e..ec34180a 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -31,14 +31,14 @@ const ( // Notice: even if canceled in server side, 499 will be logged as well. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - if duration > 0 { - return &timeoutHandler{ - handler: next, - dt: duration, - } + if duration <= 0 { + return next } - return next + return &timeoutHandler{ + handler: next, + dt: duration, + } } } @@ -207,9 +207,11 @@ func relevantCaller() runtime.Frame { if !strings.HasPrefix(frame.Function, "net/http.") { return frame } + if !more { break } } + return frame } diff --git a/rest/handler/timeouthandler_test.go b/rest/handler/timeouthandler_test.go index 23eb88c6..3fd8bdd2 100644 --- a/rest/handler/timeouthandler_test.go +++ b/rest/handler/timeouthandler_test.go @@ -2,21 +2,16 @@ package handler import ( "context" - "io" - "log" "net/http" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/logx/logtest" "github.com/zeromicro/go-zero/rest/internal/response" ) -func init() { - log.SetOutput(io.Discard) -} - func TestTimeout(t *testing.T) { timeoutHandler := TimeoutHandler(time.Millisecond) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -45,7 +40,12 @@ func TestWithTimeoutTimedout(t *testing.T) { timeoutHandler := TimeoutHandler(time.Millisecond) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(time.Millisecond * 10) - w.Write([]byte(`foo`)) + _, err := w.Write([]byte(`foo`)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) })) @@ -96,7 +96,12 @@ func TestTimeoutWebsocket(t *testing.T) { func TestTimeoutWroteHeaderTwice(t *testing.T) { timeoutHandler := TimeoutHandler(time.Minute) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(`hello`)) + _, err := w.Write([]byte(`hello`)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("foo", "bar") w.WriteHeader(http.StatusOK) })) @@ -145,7 +150,7 @@ func TestTimeoutHijack(t *testing.T) { } assert.NotPanics(t, func() { - writer.Hijack() + _, _, _ = writer.Hijack() }) writer = &timeoutWriter{ @@ -155,7 +160,7 @@ func TestTimeoutHijack(t *testing.T) { } assert.NotPanics(t, func() { - writer.Hijack() + _, _, _ = writer.Hijack() }) } @@ -165,7 +170,7 @@ func TestTimeoutPusher(t *testing.T) { } assert.Panics(t, func() { - handler.Push("any", nil) + _ = handler.Push("any", nil) }) handler = &timeoutWriter{ @@ -174,20 +179,44 @@ func TestTimeoutPusher(t *testing.T) { assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil)) } +func TestTimeoutWriter_Hijack(t *testing.T) { + writer := &timeoutWriter{ + w: httptest.NewRecorder(), + h: make(http.Header), + req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody), + } + _, _, err := writer.Hijack() + assert.Error(t, err) +} + +func TestTimeoutWroteTwice(t *testing.T) { + c := logtest.NewCollector(t) + writer := &timeoutWriter{ + w: &response.WithCodeResponseWriter{ + Writer: httptest.NewRecorder(), + }, + h: make(http.Header), + req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody), + } + writer.writeHeaderLocked(http.StatusOK) + writer.writeHeaderLocked(http.StatusOK) + assert.Contains(t, c.String(), "superfluous response.WriteHeader call") +} + type mockedPusher struct{} func (m mockedPusher) Header() http.Header { panic("implement me") } -func (m mockedPusher) Write(bytes []byte) (int, error) { +func (m mockedPusher) Write(_ []byte) (int, error) { panic("implement me") } -func (m mockedPusher) WriteHeader(statusCode int) { +func (m mockedPusher) WriteHeader(_ int) { panic("implement me") } -func (m mockedPusher) Push(target string, opts *http.PushOptions) error { +func (m mockedPusher) Push(_ string, _ *http.PushOptions) error { panic("implement me") } diff --git a/rest/internal/log_test.go b/rest/internal/log_test.go index f4bcc799..79e08af6 100644 --- a/rest/internal/log_test.go +++ b/rest/internal/log_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) func TestInfo(t *testing.T) { @@ -25,20 +25,11 @@ func TestInfo(t *testing.T) { } func TestError(t *testing.T) { - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - - defer func() { - logx.Reset() - logx.SetWriter(o) - }() - + c := logtest.NewCollector(t) req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) Error(req, "first") Errorf(req, "second %s", "third") - val := buf.String() + val := c.String() assert.True(t, strings.Contains(val, "first")) assert.True(t, strings.Contains(val, "second")) assert.True(t, strings.Contains(val, "third")) diff --git a/rest/server_test.go b/rest/server_test.go index bc3f6577..3a4bdded 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal/cors" @@ -22,9 +22,7 @@ import ( ) func TestNewServer(t *testing.T) { - writer := logx.Reset() - defer logx.SetWriter(writer) - logx.SetWriter(logx.NewWriter(io.Discard)) + logtest.Discard(t) const configYaml = ` Name: foo diff --git a/zrpc/internal/rpclogger_test.go b/zrpc/internal/rpclogger_test.go index 1c504343..2786fd3c 100644 --- a/zrpc/internal/rpclogger_test.go +++ b/zrpc/internal/rpclogger_test.go @@ -1,121 +1,96 @@ package internal import ( - "strings" "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" ) const content = "foo" func TestLoggerError(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Error(content) - assert.Contains(t, w.String(), content) + assert.Contains(t, c.String(), content) } func TestLoggerErrorf(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Errorf(content) - assert.Contains(t, w.String(), content) + assert.Contains(t, c.String(), content) } func TestLoggerErrorln(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Errorln(content) - assert.Contains(t, w.String(), content) + assert.Contains(t, c.String(), content) } func TestLoggerFatal(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Fatal(content) - assert.Contains(t, w.String(), content) + assert.Contains(t, c.String(), content) } func TestLoggerFatalf(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Fatalf(content) - assert.Contains(t, w.String(), content) + assert.Contains(t, c.String(), content) } func TestLoggerFatalln(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Fatalln(content) - assert.Contains(t, w.String(), content) + assert.Contains(t, c.String(), content) } func TestLoggerInfo(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Info(content) - assert.Empty(t, w.String()) + assert.Empty(t, c.String()) } func TestLoggerInfof(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Infof(content) - assert.Empty(t, w.String()) + assert.Empty(t, c.String()) } func TestLoggerWarning(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Warning(content) - assert.Empty(t, w.String()) + assert.Empty(t, c.String()) } func TestLoggerInfoln(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Infoln(content) - assert.Empty(t, w.String()) + assert.Empty(t, c.String()) } func TestLoggerWarningf(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Warningf(content) - assert.Empty(t, w.String()) + assert.Empty(t, c.String()) } func TestLoggerWarningln(t *testing.T) { - w, restore := injectLog() - defer restore() - + c := logtest.NewCollector(t) logger := new(Logger) logger.Warningln(content) - assert.Empty(t, w.String()) + assert.Empty(t, c.String()) } func TestLogger_V(t *testing.T) { @@ -125,15 +100,3 @@ func TestLogger_V(t *testing.T) { // grpclog.infoLog assert.False(t, logger.V(0)) } - -func injectLog() (r *strings.Builder, restore func()) { - var buf strings.Builder - w := logx.NewWriter(&buf) - o := logx.Reset() - logx.SetWriter(w) - - return &buf, func() { - logx.Reset() - logx.SetWriter(o) - } -} diff --git a/zrpc/internal/rpcserver_test.go b/zrpc/internal/rpcserver_test.go index 90f4bc9c..99b48eb4 100644 --- a/zrpc/internal/rpcserver_test.go +++ b/zrpc/internal/rpcserver_test.go @@ -4,6 +4,7 @@ import ( "context" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/proc" @@ -40,6 +41,8 @@ func TestRpcServer(t *testing.T) { }() wg.Wait() + time.Sleep(100 * time.Millisecond) + lock.Lock() grpcServer.GracefulStop() lock.Unlock()