diff --git a/core/cmdline/input_test.go b/core/cmdline/input_test.go index 0a390a6e..c93511fc 100644 --- a/core/cmdline/input_test.go +++ b/core/cmdline/input_test.go @@ -8,20 +8,14 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/iox" "github.com/tal-tech/go-zero/core/lang" ) func TestEnterToContinue(t *testing.T) { - r, w, err := os.Pipe() + restore, err := iox.RedirectInOut() assert.Nil(t, err) - ow := os.Stdout - os.Stdout = w - or := os.Stdin - os.Stdin = r - defer func() { - os.Stdin = or - os.Stdout = ow - }() + defer restore() var wg sync.WaitGroup wg.Add(2) diff --git a/core/iox/pipe.go b/core/iox/pipe.go new file mode 100644 index 00000000..be36479c --- /dev/null +++ b/core/iox/pipe.go @@ -0,0 +1,23 @@ +package iox + +import "os" + +// RedirectInOut redirects stdin to r, stdout to w, and callers need to call restore afterwards. +func RedirectInOut() (restore func(), err error) { + var r, w *os.File + r, w, err = os.Pipe() + if err != nil { + return + } + + ow := os.Stdout + os.Stdout = w + or := os.Stdin + os.Stdin = r + restore = func() { + os.Stdin = or + os.Stdout = ow + } + + return +} diff --git a/core/iox/pipe_test.go b/core/iox/pipe_test.go new file mode 100644 index 00000000..2f4479b9 --- /dev/null +++ b/core/iox/pipe_test.go @@ -0,0 +1,13 @@ +package iox + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRedirectInOut(t *testing.T) { + restore, err := RedirectInOut() + assert.Nil(t, err) + defer restore() +} diff --git a/core/logx/tracelogger_test.go b/core/logx/tracelogger_test.go index f1e76000..2c2640f9 100644 --- a/core/logx/tracelogger_test.go +++ b/core/logx/tracelogger_test.go @@ -2,8 +2,10 @@ package logx import ( "context" + "log" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/core/trace/tracespec" @@ -24,6 +26,65 @@ func TestTraceLog(t *testing.T) { assert.True(t, strings.Contains(buf.String(), mockSpanId)) } +func TestTraceError(t *testing.T) { + var buf strings.Builder + ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) + l := WithContext(ctx).(*traceLogger) + SetLevel(InfoLevel) + errorLog = newLogWriter(log.New(&buf, "", flags)) + l.WithDuration(time.Second).Error(testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) + buf.Reset() + l.WithDuration(time.Second).Errorf(testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) +} + +func TestTraceInfo(t *testing.T) { + var buf strings.Builder + ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) + l := WithContext(ctx).(*traceLogger) + SetLevel(InfoLevel) + infoLog = newLogWriter(log.New(&buf, "", flags)) + l.WithDuration(time.Second).Info(testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) + buf.Reset() + l.WithDuration(time.Second).Infof(testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) +} + +func TestTraceSlow(t *testing.T) { + var buf strings.Builder + ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) + l := WithContext(ctx).(*traceLogger) + SetLevel(InfoLevel) + slowLog = newLogWriter(log.New(&buf, "", flags)) + l.WithDuration(time.Second).Slow(testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) + buf.Reset() + l.WithDuration(time.Second).Slowf(testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) +} + +func TestTraceWithoutContext(t *testing.T) { + var buf strings.Builder + l := WithContext(context.Background()).(*traceLogger) + SetLevel(InfoLevel) + infoLog = newLogWriter(log.New(&buf, "", flags)) + l.WithDuration(time.Second).Info(testlog) + assert.False(t, strings.Contains(buf.String(), mockTraceId)) + assert.False(t, strings.Contains(buf.String(), mockSpanId)) + buf.Reset() + l.WithDuration(time.Second).Infof(testlog) + assert.False(t, strings.Contains(buf.String(), mockTraceId)) + assert.False(t, strings.Contains(buf.String(), mockSpanId)) +} + type mockTrace struct{} func (t mockTrace) TraceId() string {