From dc286a03f58f2eaf9cc0f29f4c3df8550edb56bc Mon Sep 17 00:00:00 2001 From: kevin Date: Sun, 23 Aug 2020 15:53:10 +0800 Subject: [PATCH] add more tests --- rpcx/internal/auth/credential_test.go | 62 +++++++ .../authinterceptor_test.go | 163 +++++++++++++++++- 2 files changed, 216 insertions(+), 9 deletions(-) create mode 100644 rpcx/internal/auth/credential_test.go diff --git a/rpcx/internal/auth/credential_test.go b/rpcx/internal/auth/credential_test.go new file mode 100644 index 00000000..f80a0e86 --- /dev/null +++ b/rpcx/internal/auth/credential_test.go @@ -0,0 +1,62 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" +) + +func TestParseCredential(t *testing.T) { + tests := []struct { + name string + withNil bool + withEmptyMd bool + app string + token string + }{ + { + name: "nil", + withNil: true, + }, + { + name: "empty md", + withEmptyMd: true, + }, + { + name: "empty", + }, + { + name: "valid", + app: "foo", + token: "bar", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var ctx context.Context + if test.withNil { + ctx = context.Background() + } else if test.withEmptyMd { + ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{}) + } else { + md := metadata.New(map[string]string{ + "app": test.app, + "token": test.token, + }) + ctx = metadata.NewIncomingContext(context.Background(), md) + } + cred := ParseCredential(ctx) + assert.False(t, cred.RequireTransportSecurity()) + m, err := cred.GetRequestMetadata(context.Background()) + assert.Nil(t, err) + assert.Equal(t, test.app, m[appKey]) + assert.Equal(t, test.token, m[tokenKey]) + }) + } +} diff --git a/rpcx/internal/serverinterceptors/authinterceptor_test.go b/rpcx/internal/serverinterceptors/authinterceptor_test.go index 9932fc62..987e9194 100644 --- a/rpcx/internal/serverinterceptors/authinterceptor_test.go +++ b/rpcx/internal/serverinterceptors/authinterceptor_test.go @@ -8,21 +8,108 @@ import ( "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/rpcx/internal/auth" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) +func TestStreamAuthorizeInterceptor(t *testing.T) { + tests := []struct { + name string + app string + token string + strict bool + hasError bool + }{ + { + name: "strict=false", + strict: false, + hasError: false, + }, + { + name: "strict=true", + strict: true, + hasError: true, + }, + { + name: "strict=true,with token", + app: "foo", + token: "bar", + strict: true, + hasError: false, + }, + { + name: "strict=true,with error token", + app: "foo", + token: "error", + strict: true, + hasError: true, + }, + } + + r := miniredis.NewMiniRedis() + assert.Nil(t, r.Start()) + defer r.Close() + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + store := redis.NewRedis(r.Addr(), redis.NodeType) + if len(test.app) > 0 { + assert.Nil(t, store.Hset("apps", test.app, test.token)) + defer store.Hdel("apps", test.app) + } + + authenticator, err := auth.NewAuthenticator(store, "apps", test.strict) + assert.Nil(t, err) + interceptor := StreamAuthorizeInterceptor(authenticator) + md := metadata.New(map[string]string{ + "app": "foo", + "token": "bar", + }) + ctx := metadata.NewIncomingContext(context.Background(), md) + stream := mockedStream{ctx: ctx} + err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error { + return nil + }) + if test.hasError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} + func TestUnaryAuthorizeInterceptor(t *testing.T) { tests := []struct { - name string - strict bool + name string + app string + token string + strict bool + hasError bool }{ { - name: "strict=true", - strict: true, + name: "strict=false", + strict: false, + hasError: false, + }, + { + name: "strict=true", + strict: true, + hasError: true, + }, + { + name: "strict=true,with token", + app: "foo", + token: "bar", + strict: true, + hasError: false, }, { - name: "strict=false", - strict: false, + name: "strict=true,with error token", + app: "foo", + token: "error", + strict: true, + hasError: true, }, } @@ -33,23 +120,81 @@ func TestUnaryAuthorizeInterceptor(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { store := redis.NewRedis(r.Addr(), redis.NodeType) + if len(test.app) > 0 { + assert.Nil(t, store.Hset("apps", test.app, test.token)) + defer store.Hdel("apps", test.app) + } + authenticator, err := auth.NewAuthenticator(store, "apps", test.strict) assert.Nil(t, err) interceptor := UnaryAuthorizeInterceptor(authenticator) md := metadata.New(map[string]string{ - "app": "name", - "token": "key", + "app": "foo", + "token": "bar", }) ctx := metadata.NewIncomingContext(context.Background(), md) _, err = interceptor(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) - if test.strict { + if test.hasError { assert.NotNil(t, err) } else { assert.Nil(t, err) } + if test.strict { + _, err = interceptor(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.NotNil(t, err) + + var md metadata.MD + ctx := metadata.NewIncomingContext(context.Background(), md) + _, err = interceptor(ctx, nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.NotNil(t, err) + + md = metadata.New(map[string]string{ + "app": "", + "token": "", + }) + ctx = metadata.NewIncomingContext(context.Background(), md) + _, err = interceptor(ctx, nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.NotNil(t, err) + } }) } } + +type mockedStream struct { + ctx context.Context +} + +func (m mockedStream) SetHeader(md metadata.MD) error { + return nil +} + +func (m mockedStream) SendHeader(md metadata.MD) error { + return nil +} + +func (m mockedStream) SetTrailer(md metadata.MD) { +} + +func (m mockedStream) Context() context.Context { + return m.ctx +} + +func (m mockedStream) SendMsg(v interface{}) error { + return nil +} + +func (m mockedStream) RecvMsg(v interface{}) error { + return nil +}