You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/rpcx/internal/chainserverinterceptors.go

82 lines
2.2 KiB
Go

4 years ago
package internal
4 years ago
import (
"context"
"google.golang.org/grpc"
)
func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
}
func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
}
func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
switch len(interceptors) {
case 0:
return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return handler(srv, stream)
}
case 1:
return interceptors[0]
default:
last := len(interceptors) - 1
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
var chainHandler grpc.StreamHandler
var current int
chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
if current == last {
return handler(curSrv, curStream)
}
current++
err := interceptors[current](curSrv, curStream, info, chainHandler)
current--
return err
}
return interceptors[0](srv, stream, info, chainHandler)
}
}
}
func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
switch len(interceptors) {
case 0:
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
interface{}, error) {
return handler(ctx, req)
}
case 1:
return interceptors[0]
default:
last := len(interceptors) - 1
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
interface{}, error) {
var chainHandler grpc.UnaryHandler
var current int
chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
if current == last {
return handler(curCtx, curReq)
}
current++
resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
current--
return resp, err
}
return interceptors[0](ctx, req, info, chainHandler)
}
}
}