You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/zrpc/internal/serverinterceptors/authinterceptor_test.go

198 lines
4.2 KiB
Go

4 years ago
package serverinterceptors
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
4 years ago
"github.com/tal-tech/go-zero/core/stores/redistest"
"github.com/tal-tech/go-zero/zrpc/internal/auth"
4 years ago
"google.golang.org/grpc"
4 years ago
"google.golang.org/grpc/metadata"
)
4 years ago
func TestStreamAuthorizeInterceptor(t *testing.T) {
tests := []struct {
name string
app string
token string
strict bool
hasError bool
}{
{
name: "strict=false",
strict: false,
hasError: false,
},
{
name: "strict=true",
strict: true,
hasError: true,
},
{
name: "strict=true,with token",
app: "foo",
token: "bar",
strict: true,
hasError: false,
},
{
name: "strict=true,with error token",
app: "foo",
token: "error",
strict: true,
hasError: true,
},
}
4 years ago
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
4 years ago
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if len(test.app) > 0 {
assert.Nil(t, store.Hset("apps", test.app, test.token))
defer store.Hdel("apps", test.app)
}
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
assert.Nil(t, err)
interceptor := StreamAuthorizeInterceptor(authenticator)
md := metadata.New(map[string]string{
"app": "foo",
"token": "bar",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
stream := mockedStream{ctx: ctx}
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
return nil
})
if test.hasError {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}
4 years ago
func TestUnaryAuthorizeInterceptor(t *testing.T) {
tests := []struct {
4 years ago
name string
app string
token string
strict bool
hasError bool
4 years ago
}{
{
4 years ago
name: "strict=false",
strict: false,
hasError: false,
},
{
name: "strict=true",
strict: true,
hasError: true,
},
{
name: "strict=true,with token",
app: "foo",
token: "bar",
strict: true,
hasError: false,
4 years ago
},
{
4 years ago
name: "strict=true,with error token",
app: "foo",
token: "error",
strict: true,
hasError: true,
4 years ago
},
}
4 years ago
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
4 years ago
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
4 years ago
if len(test.app) > 0 {
assert.Nil(t, store.Hset("apps", test.app, test.token))
defer store.Hdel("apps", test.app)
}
4 years ago
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
assert.Nil(t, err)
interceptor := UnaryAuthorizeInterceptor(authenticator)
md := metadata.New(map[string]string{
4 years ago
"app": "foo",
"token": "bar",
4 years ago
})
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
4 years ago
if test.hasError {
4 years ago
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
4 years ago
if test.strict {
_, err = interceptor(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
var md metadata.MD
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
md = metadata.New(map[string]string{
"app": "",
"token": "",
})
ctx = metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
}
4 years ago
})
}
}
4 years ago
// mockedStream is a minimal fake server stream for the interceptor tests.
// Context returns the context the value was built with. Every other method
// is a no-op, and those that return an error always return nil.
//
// Methods use value receivers so a plain mockedStream{} value can be passed
// wherever the stream interface is expected.
type mockedStream struct {
	ctx context.Context // returned verbatim by Context
}

// Context returns the context supplied when the mock was constructed.
func (s mockedStream) Context() context.Context {
	return s.ctx
}

// SetHeader discards the header metadata and reports success.
func (s mockedStream) SetHeader(_ metadata.MD) error {
	return nil
}

// SendHeader discards the header metadata and reports success.
func (s mockedStream) SendHeader(_ metadata.MD) error {
	return nil
}

// SetTrailer discards the trailer metadata.
func (s mockedStream) SetTrailer(_ metadata.MD) {}

// SendMsg drops the message and reports success.
func (s mockedStream) SendMsg(_ interface{}) error {
	return nil
}

// RecvMsg leaves the target untouched and reports success.
func (s mockedStream) RecvMsg(_ interface{}) error {
	return nil
}