写点什么

Grpc-go 源码剖析

用户头像
王博
关注
发布于: 2021 年 06 月 01 日

服务端

服务端测试代码

以下测试代码取自 grpc-go 仓库的 examples/helloworld 示例。从中可以看出,grpc 服务端主要分为以下几步:

1.实例化 Server

2.注册 Service

3.监听并接收连接请求


const (  port = ":50051")
// server is used to implement helloworld.GreeterServer.type server struct { pb.UnimplementedGreeterServer}
// SayHello implements helloworld.GreeterServerfunc (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { log.Printf("Received: %v", in.GetName()) return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil}
func main() { lis, err := net.Listen("tcp", port) if err != nil { log.Fatalf("failed to listen: %v", err) } s := grpc.NewServer() pb.RegisterGreeterServer(s, &server{}) if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) }}
复制代码

实例化 Server

func NewServer(opt ...ServerOption) *Server {	opts := defaultServerOptions	//设置定制参数	for _, o := range opt {		o.apply(&opts)	}	//初始化server对象	s := &Server{		lis:      make(map[net.Listener]bool),		opts:     opts,		conns:    make(map[transport.ServerTransport]bool),		services: make(map[string]*serviceInfo),		quit:     grpcsync.NewEvent(),		done:     grpcsync.NewEvent(),		czData:   new(channelzData),	}	chainUnaryServerInterceptors(s)	chainStreamServerInterceptors(s)	s.cv = sync.NewCond(&s.mu)	if EnableTracing {		_, file, line, _ := runtime.Caller(1)		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))	}
if s.opts.numServerWorkers > 0 { s.initServerWorkers() }
if channelz.IsOn() { s.channelzID = channelz.RegisterServer(&channelzServer{s}, "") } return s}
复制代码

初始化 Server 对象比较简单,Server 主要包含以下成员:

  • lis:监听地址列表

  • opts:服务选项,可以设置一些基础配置

  • conns:客户端连接列表

  • services:service 列表(以服务名为键),一个 server 对应多个 service,一个 service 对应多个方法

  • quit:退出信号

  • done:完成信号

  • czData:用于存储 ClientConn,addrConn 和 Server 的 channelz 相关数据

  • cv:条件变量(sync.Cond,基于 s.mu 创建)。优雅退出时会在它上面等待,直到所有 RPC 请求都处理完毕、连接断开后才会继续执行。

注册 Service

先看一下怎么注册 Service。在 helloworld.pb.go 文件中,会有 RegisterGreeterServer 方法以及 Greeter_ServiceDesc 变量,Greeter_ServiceDesc 描述了服务的属性。RegisterGreeterServer 方法会向 gRPC 服务端 s 注册服务 srv。


pb.RegisterGreeterServer(s, &server{})func RegisterGreeterServer(s grpc.ServiceRegistrar, srv GreeterServer) {  s.RegisterService(&Greeter_ServiceDesc, srv)}var Greeter_ServiceDesc = grpc.ServiceDesc{  ServiceName: "helloworld.Greeter",//服务名称  HandlerType: (*GreeterServer)(nil),//服务接口  Methods: []grpc.MethodDesc{//一元方法集    {      MethodName: "SayHello",      Handler:    _Greeter_SayHello_Handler,    },  },  Streams:  []grpc.StreamDesc{},//流式方法集  Metadata: "examples/helloworld/helloworld/helloworld.proto",//元数据}
复制代码


服务注册的具体实现是 server.go 中的 RegisterService 方法。它先通过反射判断传入的服务实现 ss 是否实现了 sd.HandlerType 描述的接口:未实现则直接 Fatal 退出,实现了则调用 s.register 完成注册。


func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {  if ss != nil {    ht := reflect.TypeOf(sd.HandlerType).Elem()    st := reflect.TypeOf(ss)    if !st.Implements(ht) {      logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)    }  }  s.register(sd, ss)}
复制代码


register 为一元方法和流式方法分别创建 map,以方法名为键、方法描述(指针)为值加入对应的 map,最后以服务名为键,把 serviceInfo 存入 s.services(一个 server 对应多个 service)。注意:如果在 Serve 之后调用,或者重复注册同名服务,register 会直接 Fatal。


func (s *Server) register(sd *ServiceDesc, ss interface{}) {  s.mu.Lock()  defer s.mu.Unlock()  s.printf("RegisterService(%q)", sd.ServiceName)  if s.serve {    logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)  }  if _, ok := s.services[sd.ServiceName]; ok {    logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)  }  info := &serviceInfo{    serviceImpl: ss,    methods:     make(map[string]*MethodDesc),    streams:     make(map[string]*StreamDesc),    mdata:       sd.Metadata,  }  for i := range sd.Methods {    d := &sd.Methods[i]    info.methods[d.MethodName] = d  }  for i := range sd.Streams {    d := &sd.Streams[i]    info.streams[d.StreamName] = d  }  s.services[sd.ServiceName] = info}
复制代码

监听

监听处理请求核心代码如下:

func (s *Server) Serve(lis net.Listener) error {	s.mu.Lock()	s.printf("serving")	s.serve = true	if s.lis == nil {		// Serve called after Stop or GracefulStop.		s.mu.Unlock()		lis.Close()		return ErrServerStopped	}
s.serveWG.Add(1) defer func() { s.serveWG.Done() if s.quit.HasFired() { // Stop or GracefulStop called; block until done and return nil. <-s.done.Done() } }()
ls := &listenSocket{Listener: lis} s.lis[ls] = true
if channelz.IsOn() { ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String()) } s.mu.Unlock()
defer func() { s.mu.Lock() if s.lis != nil && s.lis[ls] { ls.Close() delete(s.lis, ls) } s.mu.Unlock() }()
var tempDelay time.Duration // how long to sleep on accept failure // 循环处理连接,每个连接使用一个goroutine处理 // accept如果失败,则下次accept之前睡眠一段时间 for { rawConn, err := lis.Accept() if err != nil { if ne, ok := err.(interface { Temporary() bool }); ok && ne.Temporary() { if tempDelay == 0 { // 初始化5ms tempDelay = 5 * time.Millisecond } else { //否则翻倍 tempDelay *= 2 } //不超过1分钟 if max := 1 * time.Second; tempDelay > max { tempDelay = max } s.mu.Lock() s.printf("Accept error: %v; retrying in %v", err, tempDelay) s.mu.Unlock() // 等待超时重试,或者context事件的发生 timer := time.NewTimer(tempDelay) select { case <-timer.C: case <-s.quit.Done(): timer.Stop() return nil } continue } s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock()
if s.quit.HasFired() { return nil } return err } // 重置延时 tempDelay = 0 // Start a new goroutine to deal with rawConn so we don't stall this Accept // loop goroutine. // // Make sure we account for the goroutine so GracefulStop doesn't nil out // s.conns before this conn can be added. s.serveWG.Add(1) // 每个新的tcp连接使用单独的goroutine处理 go func() { s.handleRawConn(rawConn) s.serveWG.Done() }() }}
复制代码


对于监听处理请求来说,核心实现为:

  • 不断地从 lis.Accept 取出连接。如果返回的是临时性错误(实现了 Temporary() 且返回 true),则休眠一段时间后重试,避免出错后立即反复重试;如果是其他错误,则结束 Serve:服务已被停止时返回 nil,否则返回该错误

  • 休眠策略为指数退避:第一次休眠 5ms,之后每次翻倍,最大不超过 1s(有点类似 slice 扩容时的翻倍)

  • 成功接收到连接后,会重置休眠时间,并启动一个 goroutine 去处理该连接。也就是说,每一个连接(而不是每一个请求)都由单独的 goroutine 处理

  • 加入 waitGroup 用来处理优雅重启或退出,等待所有 goroutine 执行结束之后才会退出

服务端执行调用

在注册 service 时我们知道,pb.RegisterGreeterServer(s, &server{}) 传入的第二个参数是我们自定义的、实现了相应接口的实现类。在 service 注册阶段,我们以方法名为 key、以方法描述(MethodDesc/StreamDesc)为 value 存入 map(一元方法存在 info.methods,流式方法存在 info.streams),所以可以根据方法名找到对应的处理函数(例如 _Greeter_SayHello_Handler)。


// Greeter_ServiceDesc describes the helloworld.Greeter service.
// Its single unary method, SayHello, is dispatched to
// _Greeter_SayHello_Handler; it declares no streaming methods.
var Greeter_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "helloworld.Greeter",
	HandlerType: (*GreeterServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "SayHello",
			Handler:    _Greeter_SayHello_Handler,
		},
	},
	Streams:  []grpc.StreamDesc{},
	Metadata: "examples/helloworld/helloworld/helloworld.proto",
}
复制代码

客户端

客户端测试代码

const (	address     = "localhost:50051"	defaultName = "world")
func main() { // Set up a connection to the server. conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock()) if err != nil { log.Fatalf("did not connect: %v", err) } defer conn.Close() c := pb.NewGreeterClient(conn)
// Contact the server and print out its response. name := defaultName if len(os.Args) > 1 { name = os.Args[1] } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() r, err := c.SayHello(ctx, &pb.HelloRequest{Name: name}) if err != nil { log.Fatalf("could not greet: %v", err) } log.Printf("Greeting: %s", r.GetMessage())}
复制代码

建立拨号连接

func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {	cc := &ClientConn{		target:            target,		csMgr:             &connectivityStateManager{},		conns:             make(map[*addrConn]struct{}),		dopts:             defaultDialOptions(),		blockingpicker:    newPickerWrapper(),		czData:            new(channelzData),		firstResolveEvent: grpcsync.NewEvent(),	}	cc.retryThrottler.Store((*retryThrottler)(nil))	cc.ctx, cc.cancel = context.WithCancel(context.Background())
for _, opt := range opts { opt.apply(&cc.dopts) }
chainUnaryClientInterceptors(cc) chainStreamClientInterceptors(cc)
defer func() { if err != nil { cc.Close() } }()
if channelz.IsOn() { if cc.dopts.channelzParentID != 0 { cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target) channelz.AddTraceEvent(logger, cc.channelzID, 0, &channelz.TraceEventDesc{ Desc: "Channel Created", Severity: channelz.CtInfo, Parent: &channelz.TraceEventDesc{ Desc: fmt.Sprintf("Nested Channel(id:%d) created", cc.channelzID), Severity: channelz.CtInfo, }, }) } else { cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, 0, target) channelz.Info(logger, cc.channelzID, "Channel Created") } cc.csMgr.channelzID = cc.channelzID }
if !cc.dopts.insecure { if cc.dopts.copts.TransportCredentials == nil && cc.dopts.copts.CredsBundle == nil { return nil, errNoTransportSecurity } if cc.dopts.copts.TransportCredentials != nil && cc.dopts.copts.CredsBundle != nil { return nil, errTransportCredsAndBundle } } else { if cc.dopts.copts.TransportCredentials != nil || cc.dopts.copts.CredsBundle != nil { return nil, errCredentialsConflict } for _, cd := range cc.dopts.copts.PerRPCCredentials { if cd.RequireTransportSecurity() { return nil, errTransportCredentialsMissing } } }
if cc.dopts.defaultServiceConfigRawJSON != nil { scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON) if scpr.Err != nil { return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err) } cc.dopts.defaultServiceConfig, _ = scpr.Config.(*ServiceConfig) } cc.mkp = cc.dopts.copts.KeepaliveParams
if cc.dopts.copts.UserAgent != "" { cc.dopts.copts.UserAgent += " " + grpcUA } else { cc.dopts.copts.UserAgent = grpcUA }
if cc.dopts.timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout) defer cancel() } defer func() { select { case <-ctx.Done(): switch { case ctx.Err() == err: conn = nil case err == nil || !cc.dopts.returnLastError: conn, err = nil, ctx.Err() default: conn, err = nil, fmt.Errorf("%v: %v", ctx.Err(), err) } default: } }()
scSet := false if cc.dopts.scChan != nil { // Try to get an initial service config. select { case sc, ok := <-cc.dopts.scChan: if ok { cc.sc = &sc cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) scSet = true } default: } } if cc.dopts.bs == nil { cc.dopts.bs = backoff.DefaultExponential }
// Determine the resolver to use. cc.parsedTarget = grpcutil.ParseTarget(cc.target, cc.dopts.copts.Dialer != nil) channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme) resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme) if resolverBuilder == nil { // If resolver builder is still nil, the parsed target's scheme is // not registered. Fallback to default resolver and set Endpoint to // the original target. channelz.Infof(logger, cc.channelzID, "scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme) cc.parsedTarget = resolver.Target{ Scheme: resolver.GetDefaultScheme(), Endpoint: target, } resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme) if resolverBuilder == nil { return nil, fmt.Errorf("could not get resolver for default scheme: %q", cc.parsedTarget.Scheme) } }
creds := cc.dopts.copts.TransportCredentials if creds != nil && creds.Info().ServerName != "" { cc.authority = creds.Info().ServerName } else if cc.dopts.insecure && cc.dopts.authority != "" { cc.authority = cc.dopts.authority } else if strings.HasPrefix(cc.target, "unix:") || strings.HasPrefix(cc.target, "unix-abstract:") { cc.authority = "localhost" } else if strings.HasPrefix(cc.parsedTarget.Endpoint, ":") { cc.authority = "localhost" + cc.parsedTarget.Endpoint } else { // Use endpoint from "scheme://authority/endpoint" as the default // authority for ClientConn. cc.authority = cc.parsedTarget.Endpoint }
if cc.dopts.scChan != nil && !scSet { // Blocking wait for the initial service config. select { case sc, ok := <-cc.dopts.scChan: if ok { cc.sc = &sc cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) } case <-ctx.Done(): return nil, ctx.Err() } } if cc.dopts.scChan != nil { go cc.scWatcher() }
var credsClone credentials.TransportCredentials if creds := cc.dopts.copts.TransportCredentials; creds != nil { credsClone = creds.Clone() } cc.balancerBuildOpts = balancer.BuildOptions{ DialCreds: credsClone, CredsBundle: cc.dopts.copts.CredsBundle, Dialer: cc.dopts.copts.Dialer, CustomUserAgent: cc.dopts.copts.UserAgent, ChannelzParentID: cc.channelzID, Target: cc.parsedTarget, }
// Build the resolver. rWrapper, err := newCCResolverWrapper(cc, resolverBuilder) if err != nil { return nil, fmt.Errorf("failed to build resolver: %v", err) } cc.mu.Lock() cc.resolverWrapper = rWrapper cc.mu.Unlock()
// A blocking dial blocks until the clientConn is ready. if cc.dopts.block { for { s := cc.GetState() if s == connectivity.Ready { break } else if cc.dopts.copts.FailOnNonTempDialError && s == connectivity.TransientFailure { if err = cc.connectionError(); err != nil { terr, ok := err.(interface { Temporary() bool }) if ok && !terr.Temporary() { return nil, err } } } if !cc.WaitForStateChange(ctx, s) { // ctx got timeout or canceled. if err = cc.connectionError(); err != nil && cc.dopts.returnLastError { return nil, err } return nil, ctx.Err() } } }
return cc, nil}
复制代码

grpc.Dial 实际上是封装了 grpc.DialContext,主要是承担了以下职责:

  • 初始化 ClientConn 对象

  • 初始化重试规则

  • 执行一些可选方法

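  • 初始化一元/流式拦截器。注意:通过 WithUnaryInterceptor/WithStreamInterceptor 只能设置一个拦截器,重复设置会相互覆盖;如需多个拦截器,可以使用 WithChainUnaryInterceptor/WithChainStreamInterceptor,chainUnaryClientInterceptors/chainStreamClientInterceptors 会把它们串联成一个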

  • 初始化负载均衡策略

  • 初始化并解析地址信息

  • 构建 resolver 并触发连接建立(只有设置了 WithBlock 时,DialContext 才会阻塞等待连接进入 Ready 状态)

初始化 client 对象

// NewGreeterClient wraps cc in a GreeterClient. Every RPC issued through the
// returned value is sent over cc.
func NewGreeterClient(cc grpc.ClientConnInterface) GreeterClient {
	client := &greeterClient{cc}
	return client
}
复制代码

这里只是把拨号连接传给 client,比较简单,没什么好说的

调用

func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) {	out := new(HelloReply)	err := c.cc.Invoke(ctx, "/helloworld.Greeter/SayHello", in, out, opts...)	if err != nil {		return nil, err	}	return out, nil}func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {	// allow interceptor to see all applicable call options, which means those	// configured as defaults from dial option as well as per-call options	opts = combine(cc.dopts.callOptions, opts)
if cc.dopts.unaryInt != nil { return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...) } return invoke(ctx, method, args, reply, cc, opts...)}
复制代码

可以看到,Invoke 前面主要是组装调用选项,即把 DialOption 中配置的默认 CallOption 与本次调用传入的 opts 合并。如果配置了一元拦截器,就交给拦截器处理(invoke 作为后续调用传入);否则直接调用 invoke 方法。

func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {	cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)	if err != nil {		return err	}	if err := cs.SendMsg(req); err != nil {		return err	}	return cs.RecvMsg(reply)}
复制代码

invoke 方法主要包括三部分:

  • newClientStream:获取传输层 Transport 并组合封装到 ClientStream 中返回,这一步会涉及负载均衡、超时控制等操作

  • SendMsg:发送 RPC 请求

  • RecvMsg:阻塞等待接收 RPC 方法的响应结果并返回。

发布于: 2021 年 06 月 01 日阅读数: 27
用户头像

王博

关注

我是一名后端,写代码的憨憨 2018.12.29 加入

还未添加个人简介

评论

发布
暂无评论
Grpc-go源码剖析