grpc-go 源码剖析
服务端
服务端测试代码
测试代码可以在 grpc-go 仓库的 examples/helloworld 示例中找到。从测试代码可以看出,gRPC 服务端的启动主要分为以下几步:
1.实例化 Server
2.注册 Service
3.监听并接收连接请求
const ( port = ":50051")
// server is used to implement helloworld.GreeterServer.
type server struct { pb.UnimplementedGreeterServer}
// SayHello implements helloworld.GreeterServer
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { log.Printf("Received: %v", in.GetName()) return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil}
func main() { lis, err := net.Listen("tcp", port) if err != nil { log.Fatalf("failed to listen: %v", err) } s := grpc.NewServer() pb.RegisterGreeterServer(s, &server{}) if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) }}
实例化 Server
func NewServer(opt ...ServerOption) *Server { opts := defaultServerOptions //设置定制参数 for _, o := range opt { o.apply(&opts) } //初始化server对象 s := &Server{ lis: make(map[net.Listener]bool), opts: opts, conns: make(map[transport.ServerTransport]bool), services: make(map[string]*serviceInfo), quit: grpcsync.NewEvent(), done: grpcsync.NewEvent(), czData: new(channelzData), } chainUnaryServerInterceptors(s) chainStreamServerInterceptors(s) s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) }
if s.opts.numServerWorkers > 0 { s.initServerWorkers() }
if channelz.IsOn() { s.channelzID = channelz.RegisterServer(&channelzServer{s}, "") } return s}
初始化 Server 对象比较简单,Server 主要包含以下成员:
lis:监听地址列表
opts:服务选项,可以设置一些基础配置
conns:客户端连接列表
services:service 列表,一个 server 对应多个 service,一个 service 对应多个方法
quit:退出信号
done:完成信号
czData:用于存储 ClientConn,addrConn 和 Server 的 channelz 相关数据
cv:条件变量(sync.Cond,与 s.mu 关联)。优雅退出时会在这个条件变量上等待,直到所有 RPC 请求都处理完并断开连接才会继续执行。
注册 Service
先看一下怎么注册 Service。在 helloworld.pb.go 文件中,会有 RegisterGreeterServer 方法以及 Greeter_ServiceDesc 变量,Greeter_ServiceDesc 描述了服务的属性。RegisterGreeterServer 方法会向 gRPC 服务端 s 注册服务 srv。
pb.RegisterGreeterServer(s, &server{})func RegisterGreeterServer(s grpc.ServiceRegistrar, srv GreeterServer) { s.RegisterService(&Greeter_ServiceDesc, srv)}var Greeter_ServiceDesc = grpc.ServiceDesc{ ServiceName: "helloworld.Greeter",//服务名称 HandlerType: (*GreeterServer)(nil),//服务接口 Methods: []grpc.MethodDesc{//一元方法集 { MethodName: "SayHello", Handler: _Greeter_SayHello_Handler, }, }, Streams: []grpc.StreamDesc{},//流式方法集 Metadata: "examples/helloworld/helloworld/helloworld.proto",//元数据}
服务注册的具体实现是 server.go 中的 RegisterService 方法:当传入的服务实现 ss 不为 nil 时,它通过反射判断 ss 是否实现了 sd 中描述的 HandlerType 接口,未实现则直接 Fatal 退出;检查通过后调用 s.register 方法完成注册。
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) { if ss != nil { ht := reflect.TypeOf(sd.HandlerType).Elem() st := reflect.TypeOf(ss) if !st.Implements(ht) { logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht) } } s.register(sd, ss)}
register 根据 Method 创建对应的 map,并将名称作为键,方法描述(指针)作为值,添加到相应的 map 中。最后将{服务名称:服务}添加到 server。(一个 server 对应多个 service)
func (s *Server) register(sd *ServiceDesc, ss interface{}) { s.mu.Lock() defer s.mu.Unlock() s.printf("RegisterService(%q)", sd.ServiceName) if s.serve { logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName) } if _, ok := s.services[sd.ServiceName]; ok { logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName) } info := &serviceInfo{ serviceImpl: ss, methods: make(map[string]*MethodDesc), streams: make(map[string]*StreamDesc), mdata: sd.Metadata, } for i := range sd.Methods { d := &sd.Methods[i] info.methods[d.MethodName] = d } for i := range sd.Streams { d := &sd.Streams[i] info.streams[d.StreamName] = d } s.services[sd.ServiceName] = info}
监听
监听处理请求核心代码如下:
func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() s.printf("serving") s.serve = true if s.lis == nil { // Serve called after Stop or GracefulStop. s.mu.Unlock() lis.Close() return ErrServerStopped }
s.serveWG.Add(1) defer func() { s.serveWG.Done() if s.quit.HasFired() { // Stop or GracefulStop called; block until done and return nil. <-s.done.Done() } }()
ls := &listenSocket{Listener: lis} s.lis[ls] = true
if channelz.IsOn() { ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String()) } s.mu.Unlock()
defer func() { s.mu.Lock() if s.lis != nil && s.lis[ls] { ls.Close() delete(s.lis, ls) } s.mu.Unlock() }()
var tempDelay time.Duration // how long to sleep on accept failure // 循环处理连接,每个连接使用一个goroutine处理 // accept如果失败,则下次accept之前睡眠一段时间 for { rawConn, err := lis.Accept() if err != nil { if ne, ok := err.(interface { Temporary() bool }); ok && ne.Temporary() { if tempDelay == 0 { // 初始化5ms tempDelay = 5 * time.Millisecond } else { //否则翻倍 tempDelay *= 2 } //不超过1秒 if max := 1 * time.Second; tempDelay > max { tempDelay = max } s.mu.Lock() s.printf("Accept error: %v; retrying in %v", err, tempDelay) s.mu.Unlock() // 等待定时器到期后重试,或者收到退出(quit)信号后直接返回 timer := time.NewTimer(tempDelay) select { case <-timer.C: case <-s.quit.Done(): timer.Stop() return nil } continue } s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock()
if s.quit.HasFired() { return nil } return err } // 重置延时 tempDelay = 0 // Start a new goroutine to deal with rawConn so we don't stall this Accept // loop goroutine. // // Make sure we account for the goroutine so GracefulStop doesn't nil out // s.conns before this conn can be added. s.serveWG.Add(1) // 每个新的tcp连接使用单独的goroutine处理 go func() { s.handleRawConn(rawConn) s.serveWG.Done() }() }}对于监听处理请求来说,核心实现为:
不断地从 lis.Accept 取出连接;如果返回的是临时性错误(Temporary() 为 true),则休眠一段时间后重试(没必要在出错后立刻反复去拿);如果是其他错误,则结束 Serve(已触发退出时返回 nil,否则返回该错误)
休眠策略为,第一次休眠 5ms,不断翻倍,最大 1s(很类似 slice 扩容)
如果 Accept 成功拿到新连接,那么会重置休眠时间,并启动一个 goroutine 去处理该连接(handleRawConn),也就是说每一个连接都由不同的 goroutine 处理
加入 waitGroup 用来处理优雅重启或退出,等待所有 goroutine 执行结束之后才会退出
服务端执行调用
在注册 service 时,我们知道,pb.RegisterGreeterServer(s, &server{})传入的第二个参数为我们自定义的、实现了相应接口的实现类。在 service 注册阶段,我们将方法名作为 key,将方法描述作为 val 存到 map 里(一元调用的 MethodDesc 存在 info.methods,流式调用的 StreamDesc 存在 info.streams,info 本身再以 sd.ServiceName 为 key 存入 s.services),所以我们可以根据服务名和方法名找到对应的处理函数(_Greeter_SayHello_Handler)。
var Greeter_ServiceDesc = grpc.ServiceDesc{ ServiceName: "helloworld.Greeter", HandlerType: (*GreeterServer)(nil), Methods: []grpc.MethodDesc{ { MethodName: "SayHello", Handler: _Greeter_SayHello_Handler, }, }, Streams: []grpc.StreamDesc{}, Metadata: "examples/helloworld/helloworld/helloworld.proto",}
客户端
客户端测试代码
const ( address = "localhost:50051" defaultName = "world")
func main() { // Set up a connection to the server. conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock()) if err != nil { log.Fatalf("did not connect: %v", err) } defer conn.Close() c := pb.NewGreeterClient(conn)
// Contact the server and print out its response. name := defaultName if len(os.Args) > 1 { name = os.Args[1] } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() r, err := c.SayHello(ctx, &pb.HelloRequest{Name: name}) if err != nil { log.Fatalf("could not greet: %v", err) } log.Printf("Greeting: %s", r.GetMessage())}
建立拨号连接
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { cc := &ClientConn{ target: target, csMgr: &connectivityStateManager{}, conns: make(map[*addrConn]struct{}), dopts: defaultDialOptions(), blockingpicker: newPickerWrapper(), czData: new(channelzData), firstResolveEvent: grpcsync.NewEvent(), } cc.retryThrottler.Store((*retryThrottler)(nil)) cc.ctx, cc.cancel = context.WithCancel(context.Background())
for _, opt := range opts { opt.apply(&cc.dopts) }
chainUnaryClientInterceptors(cc) chainStreamClientInterceptors(cc)
defer func() { if err != nil { cc.Close() } }()
if channelz.IsOn() { if cc.dopts.channelzParentID != 0 { cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target) channelz.AddTraceEvent(logger, cc.channelzID, 0, &channelz.TraceEventDesc{ Desc: "Channel Created", Severity: channelz.CtInfo, Parent: &channelz.TraceEventDesc{ Desc: fmt.Sprintf("Nested Channel(id:%d) created", cc.channelzID), Severity: channelz.CtInfo, }, }) } else { cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, 0, target) channelz.Info(logger, cc.channelzID, "Channel Created") } cc.csMgr.channelzID = cc.channelzID }
if !cc.dopts.insecure { if cc.dopts.copts.TransportCredentials == nil && cc.dopts.copts.CredsBundle == nil { return nil, errNoTransportSecurity } if cc.dopts.copts.TransportCredentials != nil && cc.dopts.copts.CredsBundle != nil { return nil, errTransportCredsAndBundle } } else { if cc.dopts.copts.TransportCredentials != nil || cc.dopts.copts.CredsBundle != nil { return nil, errCredentialsConflict } for _, cd := range cc.dopts.copts.PerRPCCredentials { if cd.RequireTransportSecurity() { return nil, errTransportCredentialsMissing } } }
if cc.dopts.defaultServiceConfigRawJSON != nil { scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON) if scpr.Err != nil { return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err) } cc.dopts.defaultServiceConfig, _ = scpr.Config.(*ServiceConfig) } cc.mkp = cc.dopts.copts.KeepaliveParams
if cc.dopts.copts.UserAgent != "" { cc.dopts.copts.UserAgent += " " + grpcUA } else { cc.dopts.copts.UserAgent = grpcUA }
if cc.dopts.timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout) defer cancel() } defer func() { select { case <-ctx.Done(): switch { case ctx.Err() == err: conn = nil case err == nil || !cc.dopts.returnLastError: conn, err = nil, ctx.Err() default: conn, err = nil, fmt.Errorf("%v: %v", ctx.Err(), err) } default: } }()
scSet := false if cc.dopts.scChan != nil { // Try to get an initial service config. select { case sc, ok := <-cc.dopts.scChan: if ok { cc.sc = &sc cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) scSet = true } default: } } if cc.dopts.bs == nil { cc.dopts.bs = backoff.DefaultExponential }
// Determine the resolver to use. cc.parsedTarget = grpcutil.ParseTarget(cc.target, cc.dopts.copts.Dialer != nil) channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme) resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme) if resolverBuilder == nil { // If resolver builder is still nil, the parsed target's scheme is // not registered. Fallback to default resolver and set Endpoint to // the original target. channelz.Infof(logger, cc.channelzID, "scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme) cc.parsedTarget = resolver.Target{ Scheme: resolver.GetDefaultScheme(), Endpoint: target, } resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme) if resolverBuilder == nil { return nil, fmt.Errorf("could not get resolver for default scheme: %q", cc.parsedTarget.Scheme) } }
creds := cc.dopts.copts.TransportCredentials if creds != nil && creds.Info().ServerName != "" { cc.authority = creds.Info().ServerName } else if cc.dopts.insecure && cc.dopts.authority != "" { cc.authority = cc.dopts.authority } else if strings.HasPrefix(cc.target, "unix:") || strings.HasPrefix(cc.target, "unix-abstract:") { cc.authority = "localhost" } else if strings.HasPrefix(cc.parsedTarget.Endpoint, ":") { cc.authority = "localhost" + cc.parsedTarget.Endpoint } else { // Use endpoint from "scheme://authority/endpoint" as the default // authority for ClientConn. cc.authority = cc.parsedTarget.Endpoint }
if cc.dopts.scChan != nil && !scSet { // Blocking wait for the initial service config. select { case sc, ok := <-cc.dopts.scChan: if ok { cc.sc = &sc cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{&sc}) } case <-ctx.Done(): return nil, ctx.Err() } } if cc.dopts.scChan != nil { go cc.scWatcher() }
var credsClone credentials.TransportCredentials if creds := cc.dopts.copts.TransportCredentials; creds != nil { credsClone = creds.Clone() } cc.balancerBuildOpts = balancer.BuildOptions{ DialCreds: credsClone, CredsBundle: cc.dopts.copts.CredsBundle, Dialer: cc.dopts.copts.Dialer, CustomUserAgent: cc.dopts.copts.UserAgent, ChannelzParentID: cc.channelzID, Target: cc.parsedTarget, }
// Build the resolver. rWrapper, err := newCCResolverWrapper(cc, resolverBuilder) if err != nil { return nil, fmt.Errorf("failed to build resolver: %v", err) } cc.mu.Lock() cc.resolverWrapper = rWrapper cc.mu.Unlock()
// A blocking dial blocks until the clientConn is ready. if cc.dopts.block { for { s := cc.GetState() if s == connectivity.Ready { break } else if cc.dopts.copts.FailOnNonTempDialError && s == connectivity.TransientFailure { if err = cc.connectionError(); err != nil { terr, ok := err.(interface { Temporary() bool }) if ok && !terr.Temporary() { return nil, err } } } if !cc.WaitForStateChange(ctx, s) { // ctx got timeout or canceled. if err = cc.connectionError(); err != nil && cc.dopts.returnLastError { return nil, err } return nil, ctx.Err() } } }
return cc, nil}grpc.Dial 实际上是封装了 grpc.DialContext,主要是承担了以下职责:
初始化 ClientConn 对象
初始化重试规则
执行一些可选方法
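初始化一元/流式拦截器(通过 chainUnaryClientInterceptors / chainStreamClientInterceptors 把已配置的多个拦截器串成一条调用链;需要注意 WithUnaryInterceptor / WithStreamInterceptor 各自只能设置一个拦截器,要配置多个拦截器应使用 WithChainUnaryInterceptor / WithChainStreamInterceptor)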
初始化负载均衡策略
初始化并解析地址信息
建立连接
初始化 client 对象
func NewGreeterClient(cc grpc.ClientConnInterface) GreeterClient { return &greeterClient{cc}}这里只是把拨号连接传给 client,比较简单,没什么好说的
调用
func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) { out := new(HelloReply) err := c.cc.Invoke(ctx, "/helloworld.Greeter/SayHello", in, out, opts...) if err != nil { return nil, err } return out, nil}func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error { // allow interceptor to see all applicable call options, which means those // configured as defaults from dial option as well as per-call options opts = combine(cc.dopts.callOptions, opts)
if cc.dopts.unaryInt != nil { return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...) } return invoke(ctx, method, args, reply, cc, opts...)}
可以看到,Invoke 前面主要是把拨号时设置的默认 CallOption 与本次调用传入的 CallOption 合并;如果配置了一元拦截器,则交给拦截器处理(invoke 作为参数传入),否则直接调用 invoke 方法
func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error { cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...) if err != nil { return err } if err := cs.SendMsg(req); err != nil { return err } return cs.RecvMsg(reply)}invoke 方法主要包括三部分:
newClientStream:获取传输层 Transport 并组合封装到 ClientStream 中返回,在这块会涉及负载均衡、超时控制等操作
SendMsg:发送 RPC 请求
RecvMsg:阻塞等待并接收 RPC 方法的响应结果,然后返回。
版权声明: 本文为 InfoQ 作者【王博】的原创文章。
原文链接:【http://xie.infoq.cn/article/535eba4fa51be30c8a1ddb107】。文章转载请联系作者。
王博
我是一名后端,写代码的憨憨 2018.12.29 加入
还未添加个人简介











评论