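RPC Demo (2): Zookeeper-Based Service Discovery
Introduction
This builds on the previous article: RPC Demo (1) Netty RPC Demo Implementation
This second part uses Zookeeper as the service registry, which removes the explicit server address parameter from RPC calls.
Full project repository: RpcDemoJava
Improvements
In the client, we previously had to pass the backend server address explicitly, which is inconvenient. The code looked roughly like this: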
UserService userService = jdk.create(UserService.class, "http://localhost:8080/");
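With Zookeeper as the registry, the client can look up the addresses of the servers that implement an interface from Zookeeper, so the address no longer has to be passed in. After the change, the call looks roughly like this: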
UserService userService = jdk.create(UserService.class);
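Design
After some research and thought, the approach breaks down into the following steps:
1. The server registers its Providers in Zookeeper
2. The client pulls all Provider information into local memory and builds a mapping from each interface (Consumer) to its list of Providers
3. The client watches the server-side Providers for additions, removals and updates and syncs them locally, so that removed or changed Providers are reflected in the local list
4. When the client makes a reflective call, it takes a URL from the Provider list, sends the request and returns the result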
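To make steps 2 to 4 concrete before bringing ZK in, here is a minimal in-memory sketch of the client-side mapping. It is not the project's code: `InMemoryRegistry` and its `Provider` class are hypothetical names, it assumes the `service:group:version` key format used later in this article, and its lookup simply returns the first provider where the real client applies tag routing and load balancing.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

// Hypothetical in-memory stand-in for the ZK-backed registry (not the project's code).
class InMemoryRegistry {

    // Minimal provider entry: the ZK-generated id plus the server URL.
    static final class Provider {
        final String id;
        final String url;

        Provider(String id, String url) {
            this.id = id;
            this.url = url;
        }
    }

    // service:group:version -> providers; concurrent because a watcher thread writes while callers read.
    private final Map<String, List<Provider>> providers = new ConcurrentHashMap<>();

    static String key(String service, String group, String version) {
        return String.join(":", service, group, version);
    }

    // Steps 2 and 3: add a provider, or replace the cached entry that has the same id.
    // removeIf + add is not atomic; this assumes a single watcher thread applies updates.
    void upsert(String key, Provider provider) {
        List<Provider> list = providers.computeIfAbsent(key, k -> new CopyOnWriteArrayList<>());
        list.removeIf(p -> p.id.equals(provider.id));
        list.add(provider);
    }

    // Step 3: drop a provider that went offline.
    void remove(String key, String id) {
        List<Provider> list = providers.get(key);
        if (list != null) {
            list.removeIf(p -> p.id.equals(id));
        }
    }

    // Step 4: return an address for the call (real code: tag routing, then load balancing).
    String lookup(String key) {
        List<Provider> list = providers.get(key);
        if (list == null) {
            return null;
        }
        // Iterating a CopyOnWriteArrayList works on a snapshot, so a concurrent removal cannot break this.
        for (Provider p : list) {
            return p.url;
        }
        return null;
    }
}
```

Replacing entries by id, rather than always appending, keeps a provider from being cached twice when the same registration is reported more than once.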
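You need a ZK instance running locally; Docker is the easiest way. The relevant commands:
# Pull the ZK image and start ZK; the commands below assume this one has already been run
docker run -dit --name zk -p 2181:2181 zookeeper

# View the ZK logs
docker logs -f zk

# Restart ZK
docker restart zk

# Start ZK
docker start zk

# Stop ZK
docker stop zk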
Provider information structure
We define a Provider's information as follows:
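import lombok.Data;

import java.util.List;

@Data
public class ProviderInfo {

    /**
     * Provider ID: an ID is generated when the Provider registers in ZK.
     * When the client fetches the Provider list, it sets this field to the ZK-generated ID.
     */
    String id;

    /**
     * Address of the backend server behind this Provider
     */
    String url;

    /**
     * Tags: used for simple routing
     */
    List<String> tags;

    /**
     * Weight: used for weighted load balancing
     */
    Integer weight;

    // No-arg constructor, needed when the payload is deserialized from JSON
    public ProviderInfo() {}

    public ProviderInfo(String id, String url, List<String> tags, int weight) {
        this.id = id;
        this.url = url;
        this.tags = tags;
        this.weight = weight;
    }
}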
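The `weight` field is meant for weighted load balancing. The project's `WeightBalance` is not shown in this article, so the following is only a sketch of one common approach, weighted random selection, assuming a provider with weight w should receive roughly w / (sum of all weights) of the calls. `WeightedPicker` is a hypothetical name, and it uses a plain `int` weight where `ProviderInfo` uses a nullable `Integer`.

```java
import java.util.List;
import java.util.Random;

// Hypothetical weighted random selection; the project's WeightBalance may work differently.
class WeightedPicker {

    static final class Provider {
        final String url;
        final int weight;

        Provider(String url, int weight) {
            this.url = url;
            this.weight = weight;
        }
    }

    // Returns a URL with probability weight / totalWeight, or null if no provider has a positive weight.
    static String select(List<Provider> providers, Random random) {
        int total = 0;
        for (Provider p : providers) {
            if (p.weight > 0) {
                total += p.weight;
            }
        }
        if (total == 0) {
            return null;
        }
        // Pick a point in [0, total) and walk the list until the point falls inside a provider's weight range.
        int point = random.nextInt(total);
        for (Provider p : providers) {
            if (p.weight <= 0) {
                continue;
            }
            point -= p.weight;
            if (point < 0) {
                return p.url;
            }
        }
        return null; // not reached when total > 0
    }
}
```

With weights 3 and 1, the first provider is picked about three times as often as the second; passing a seeded `Random` makes the picks reproducible in tests.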
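1. The server registers its Providers in Zookeeper
First, each interface implementation needs a Provider name, group, version, tags and weight. We specify these with an annotation: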
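import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Annotation for initializing an RPC provider service
 *
 * group, version and tags all have default values, for compatibility with earlier versions
 *
 * @author lw1243925457
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface ProviderService {

    /**
     * Name of the corresponding API interface
     * @return API service
     */
    String service();

    /**
     * Group
     * @return group
     */
    String group() default "default";

    /**
     * Version
     * @return version
     */
    String version() default "default";

    /**
     * Tags: used for simple routing
     * Multiple tags are separated by commas
     * @return tags
     */
    String tags() default "";

    /**
     * Weight: used for weighted load balancing
     * @return weight
     */
    int weight() default 1;
}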
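Next, borrowing MyBatis's idea of configuring a package scan path, we scan all classes under a given package, check whether each class is a Provider (i.e. carries the annotation), and if so extract its information and register it in ZK. The code is roughly: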
import com.google.common.base.Joiner;
import lombok.extern.slf4j.Slf4j;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Initializes the RPC Providers
 * Instances are put into a Map for later lookup
 *
 * @author lw1243925457
 */
@Slf4j
public class ProviderServiceManagement {

    /**
     * Service name, group and version form the key that identifies an implementation instance
     * service:group:version --> instance
     */
    private static Map<String, Object> proxyMap = new HashMap<>();

    /**
     * Initialization: scan the package path, find all implementation classes and register them in ZK.
     * Reads the Provider annotation on each implementation class to get the service name, group and version,
     * then calls the ZK service registration to register the Provider.
     * @param packageName package containing the implementation classes
     * @param port port the service listens on
     * @throws Exception exception
     */
    public static void init(String packageName, int port) throws Exception {
        System.out.println("\n-------- Loader Rpc Provider class start ----------------------\n");

        DiscoveryServer serviceRegister = new DiscoveryServer();

        Class<?>[] classes = getClasses(packageName);
        for (Class<?> c : classes) {
            ProviderService annotation = c.getAnnotation(ProviderService.class);
            if (annotation == null) {
                continue;
            }
            String group = annotation.group();
            String version = annotation.version();
            // "".split(",") yields [""], so blank entries are dropped
            List<String> tags = Arrays.stream(annotation.tags().split(","))
                    .map(String::trim)
                    .filter(t -> !t.isEmpty())
                    .collect(Collectors.toList());
            String provider = Joiner.on(":").join(annotation.service(), group, version);
            int weight = annotation.weight();

            proxyMap.put(provider, c.getDeclaredConstructor().newInstance());
            serviceRegister.registerService(annotation.service(), group, version, port, tags, weight);

            log.info("load provider class: " + provider + " :: " + c.getName());
        }
        System.out.println("\n-------- Loader Rpc Provider class end ----------------------\n");
    }

    /**
     * Scans all classes accessible from the context class loader which belong to the given package and subpackages.
     * Only works when the classes are in directories on the file system (not inside a jar).
     *
     * @param packageName The base package
     * @return The classes
     * @throws ClassNotFoundException exception
     * @throws IOException exception
     */
    private static Class<?>[] getClasses(String packageName) throws ClassNotFoundException, IOException {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        assert classLoader != null;
        String path = packageName.replace('.', '/');
        Enumeration<URL> resources = classLoader.getResources(path);
        List<File> dirs = new ArrayList<>();
        while (resources.hasMoreElements()) {
            URL resource = resources.nextElement();
            dirs.add(new File(resource.getFile()));
        }
        List<Class<?>> classes = new ArrayList<>();
        for (File directory : dirs) {
            classes.addAll(findClasses(directory, packageName));
        }
        return classes.toArray(new Class<?>[0]);
    }

    /**
     * Recursive method used to find all classes in a given directory and subdirs.
     *
     * @param directory The base directory
     * @param packageName The package name for classes found inside the base directory
     * @return The classes
     * @throws ClassNotFoundException ClassNotFoundException
     */
    private static List<Class<?>> findClasses(File directory, String packageName) throws ClassNotFoundException {
        List<Class<?>> classes = new ArrayList<>();
        if (!directory.exists()) {
            return classes;
        }
        File[] files = directory.listFiles();
        assert files != null;
        for (File file : files) {
            if (file.isDirectory()) {
                assert !file.getName().contains(".");
                classes.addAll(findClasses(file, packageName + "." + file.getName()));
            } else if (file.getName().endsWith(".class")) {
                // Strip the ".class" suffix (6 characters) to get the simple class name
                classes.add(Class.forName(packageName + '.' + file.getName().substring(0, file.getName().length() - 6)));
            }
        }
        return classes;
    }
}
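The scan loop boils down to two reflection steps: read `@ProviderService` from a class, then turn its attributes into the `service:group:version` key and a tag list. The self-contained sketch below reproduces just those steps with a trimmed copy of the annotation and two hypothetical implementation classes (`DemoUserServiceImpl`, `PlainProviderImpl`; the `OrderService` name is also made up for the demo). It also shows why blank tags are worth filtering: `"".split(",")` returns a one-element array holding an empty string, so the default `tags = ""` would otherwise become the tag list `[""]`.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

// Trimmed copy of the @ProviderService annotation shown earlier.
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@interface ProviderService {
    String service();
    String group() default "default";
    String version() default "default";
    String tags() default "";
    int weight() default 1;
}

// Hypothetical implementation classes used only for this demo.
@ProviderService(service = "com.rpc.demo.service.UserService", group = "group2", version = "v2", tags = "tag2")
class DemoUserServiceImpl {}

@ProviderService(service = "com.rpc.demo.service.OrderService")
class PlainProviderImpl {}

class ProviderScanDemo {

    // Returns "service:group:version", or null when the class has no @ProviderService.
    static String providerKey(Class<?> c) {
        ProviderService a = c.getAnnotation(ProviderService.class);
        return a == null ? null : String.join(":", a.service(), a.group(), a.version());
    }

    // Splits the comma-separated tags and drops blanks, so the default "" becomes an empty list.
    static List<String> tags(Class<?> c) {
        ProviderService a = c.getAnnotation(ProviderService.class);
        if (a == null) {
            return Collections.emptyList();
        }
        return Arrays.stream(a.tags().split(","))
                .map(String::trim)
                .filter(t -> !t.isEmpty())
                .collect(Collectors.toList());
    }
}
```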
Next comes the ZK service registration code; a little reading of the documentation is enough to write it. Roughly:
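// ZookeeperClient.java
import lombok.extern.slf4j.Slf4j;
import org.apache.curator.RetryPolicy;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.retry.ExponentialBackoffRetry;

/**
 * ZK client, used to connect to ZK
 *
 * @author lw1243925457
 */
@Slf4j
public class ZookeeperClient {

    static final String REGISTER_ROOT_PATH = "rpc";

    protected CuratorFramework client;

    ZookeeperClient() {
        RetryPolicy retryPolicy = new ExponentialBackoffRetry(1000, 3);
        this.client = CuratorFrameworkFactory.builder()
                .connectString("localhost:2181")
                .namespace(REGISTER_ROOT_PATH)
                .retryPolicy(retryPolicy)
                .build();
        this.client.start();
        log.info("zookeeper service register init");
    }
}

// DiscoveryServer.java
import com.google.common.base.Joiner;
import org.apache.curator.x.discovery.ServiceDiscovery;
import org.apache.curator.x.discovery.ServiceDiscoveryBuilder;
import org.apache.curator.x.discovery.ServiceInstance;
import org.apache.curator.x.discovery.details.JsonInstanceSerializer;

import java.io.IOException;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.List;

/**
 * Service discovery server: registers Providers
 *
 * @author lw1243925457
 */
public class DiscoveryServer extends ZookeeperClient {

    private List<ServiceDiscovery<ProviderInfo>> discoveryList = new ArrayList<>();

    public DiscoveryServer() {
    }

    /**
     * Builds the Provider information and registers it in ZK
     * @param service Service impl name
     * @param group group
     * @param version version
     * @param port service listen port
     * @param tags route tags
     * @param weight load balance weight
     * @throws Exception exception
     */
    public void registerService(String service, String group, String version, int port, List<String> tags,
                                int weight) throws Exception {
        ProviderInfo provider = new ProviderInfo(null, null, tags, weight);

        // Instances are ephemeral by default, so they disappear when this server's ZK session ends
        ServiceInstance<ProviderInfo> instance = ServiceInstance.<ProviderInfo>builder()
                .name(Joiner.on(":").join(service, group, version))
                .port(port)
                .address(InetAddress.getLocalHost().getHostAddress())
                .payload(provider)
                .build();

        JsonInstanceSerializer<ProviderInfo> serializer = new JsonInstanceSerializer<>(ProviderInfo.class);

        // Same base path as the client side ("/" + REGISTER_ROOT_PATH)
        ServiceDiscovery<ProviderInfo> discovery = ServiceDiscoveryBuilder.builder(ProviderInfo.class)
                .client(client)
                .basePath("/" + REGISTER_ROOT_PATH)
                .thisInstance(instance)
                .serializer(serializer)
                .build();
        discovery.start();

        discoveryList.add(discovery);
    }

    public void close() throws IOException {
        for (ServiceDiscovery<ProviderInfo> discovery : discoveryList) {
            discovery.close();
        }
        client.close();
    }
}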
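It helps to know where these registrations end up in ZK. Assuming Curator's `ServiceDiscovery` stores each instance at `basePath/name/id` (our reading of its layout), and given that the client was built with `namespace("rpc")`, the full path is additionally prefixed with `/rpc`. The hypothetical `ZnodeLayout` helper below only does that string arithmetic: it shows that `basePath("rpc")` and `basePath("/rpc")` land in the same place, and how a watcher could recover the `service:group:version` key from an instance path once it knows the prefix above the service directories.

```java
// Hypothetical helper mirroring the assumed layout:
// /<namespace>/<basePath>/<service:group:version>/<instance id>
class ZnodeLayout {

    // Builds the absolute path of an instance node; leading/trailing slashes in namespace and basePath are ignored.
    static String instancePath(String namespace, String basePath, String name, String id) {
        return "/" + trimSlashes(namespace) + "/" + trimSlashes(basePath) + "/" + name + "/" + id;
    }

    // Given a path and the prefix above the service directories (e.g. "/rpc/rpc"), returns the
    // service key for instance nodes, or null for the prefix itself and per-service directory nodes.
    static String serviceKeyOf(String path, String prefix) {
        if (!path.startsWith(prefix + "/")) {
            return null;
        }
        String[] rest = path.substring(prefix.length() + 1).split("/");
        return rest.length == 2 ? rest[0] : null;
    }

    private static String trimSlashes(String s) {
        return s.replaceAll("^/+|/+$", "");
    }
}
```

Service directory nodes carry no data of their own, which is why the discovery client skips nodes with empty data.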
At this point the core server code is essentially done. Add the annotation to the interface implementation classes and start the server:
/**
 * @author lw
 */
@ProviderService(service = "com.rpc.demo.service.UserService", group = "group2", version = "v2", tags = "tag2")
public class UserServiceV2Impl implements UserService {

    @Override
    public User findById(Integer id) {
        return new User(id, "RPC group2 v2");
    }
}

public class ServerApplication {

    public static void main(String[] args) throws Exception {
        BackListFilter.addBackAddress("172.21.16.1");

        final int port = 8080;

        ProviderServiceManagement.init("com.rpc.server.demo.service.impl", port);

        final RpcNettyServer rpcNettyServer = new RpcNettyServer(port);
        try {
            rpcNettyServer.run();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            rpcNettyServer.destroy();
        }
    }
}
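2. Client-side code
2. The client pulls all Provider information into local memory and builds a mapping from each interface (Consumer) to its list of Providers
3. The client watches the server-side Providers for additions, removals and updates and syncs them locally, so that removed or changed Providers are reflected in the local list
4. When the client makes a reflective call, it takes a URL from the Provider list, sends the request and returns the result
All of the above are features the client needs. We write a service discovery client that implements them; the code is roughly: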
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import lombok.extern.slf4j.Slf4j;
import org.apache.curator.framework.recipes.cache.ChildData;
import org.apache.curator.framework.recipes.cache.CuratorCache;
import org.apache.curator.framework.recipes.cache.CuratorCacheListener;
import org.apache.curator.x.discovery.ServiceDiscovery;
import org.apache.curator.x.discovery.ServiceDiscoveryBuilder;
import org.apache.curator.x.discovery.ServiceInstance;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Service discovery client
 * Fetches the Provider list
 * Watches Provider updates
 * Finds the Provider for an interface (tag routing first, then load balancing)
 *
 * @author lw1243925457
 */
@Slf4j
public class DiscoveryClient extends ZookeeperClient {

    private enum EnumSingleton {
        /**
         * Lazily initialized enum singleton
         */
        INSTANCE;
        private final DiscoveryClient instance;

        EnumSingleton() {
            instance = new DiscoveryClient();
        }

        public DiscoveryClient getSingleton() {
            return instance;
        }
    }

    public static DiscoveryClient getInstance() {
        return EnumSingleton.INSTANCE.getSingleton();
    }

    /**
     * Provider cache
     * service:group:version -> provider instance list
     * Concurrent collections: written by the ZK listener thread, read by calling threads
     */
    private final Map<String, List<ProviderInfo>> providersCache = new ConcurrentHashMap<>();
    private final ServiceDiscovery<ProviderInfo> serviceDiscovery;
    private final CuratorCache resourcesCache;
    private LoadBalance balance;

    private DiscoveryClient() {
        serviceDiscovery = ServiceDiscoveryBuilder.builder(ProviderInfo.class)
                .client(client)
                .basePath("/" + REGISTER_ROOT_PATH)
                .build();
        try {
            serviceDiscovery.start();
        } catch (Exception e) {
            e.printStackTrace();
        }

        try {
            getAllProviders();
        } catch (Exception e) {
            e.printStackTrace();
        }

        this.resourcesCache = CuratorCache.build(this.client, "/");
        watchResources();

        // Constant-first equals avoids an NPE when no algorithm name has been set; weighted is the default
        if (ConsistentHashBalance.NAME.equals(RpcClient.getBalanceAlgorithmName())) {
            this.balance = new ConsistentHashBalance();
        } else {
            this.balance = new WeightBalance();
        }
    }

    /**
     * Fetches all Providers from ZK and caches them
     * @throws Exception exception
     */
    private void getAllProviders() throws Exception {
        System.out.println("\n\n======================= init : get all provider");

        Collection<String> serviceNames = serviceDiscovery.queryForNames();
        System.out.println(serviceNames.size() + " type(s)");
        for (String serviceName : serviceNames) {
            Collection<ServiceInstance<ProviderInfo>> instances = serviceDiscovery.queryForInstances(serviceName);
            System.out.println(serviceName);

            for (ServiceInstance<ProviderInfo> instance : instances) {
                String url = "http://" + instance.getAddress() + ":" + instance.getPort();
                ProviderInfo providerInfo = instance.getPayload();
                providerInfo.setId(instance.getId());
                providerInfo.setUrl(url);
                putProvider(instance.getName(), providerInfo);
                System.out.println("add provider: " + instance);
            }
        }

        System.out.println();
        for (String key : providersCache.keySet()) {
            System.out.println(key + " : " + providersCache.get(key));
        }
        System.out.println("======================= init : get all provider end\n\n");
    }

    /**
     * Given the interface name, group and version, returns one Provider server address
     * after tag routing and load balancing
     * @param service service name
     * @param group group
     * @param version version
     * @param tags tags
     * @param methodName method name
     * @return provider host ip
     */
    public String getProviders(String service, String group, String version, List<String> tags, String methodName) {
        String provider = Joiner.on(":").join(service, group, version);
        List<ProviderInfo> candidates = providersCache.get(provider);
        if (candidates == null || candidates.isEmpty()) {
            return null;
        }

        List<ProviderInfo> providers = FilterLine.filter(candidates, tags);
        if (providers.isEmpty()) {
            return null;
        }
        return balance.select(providers, service, methodName);
    }

    /**
     * Watches for Provider updates
     */
    private void watchResources() {
        CuratorCacheListener listener = CuratorCacheListener.builder()
                .forCreates(this::addHandler)
                .forChanges(this::changeHandler)
                .forDeletes(this::deleteHandler)
                .forInitialized(() -> log.info("Resources Cache initialized"))
                .build();
        resourcesCache.listenable().addListener(listener);
        resourcesCache.start();
    }

    /**
     * Adds a Provider
     * @param node new provider
     */
    private void addHandler(ChildData node) {
        if (providerDataEmpty(node)) {
            return;
        }
        System.out.println("\n\n=================== add new provider ============================");
        System.out.printf("Node created: [%s:%s]%n", node.getPath(), dataOf(node));
        updateProvider(node);
        System.out.println("=================== add new provider end ============================\n\n");
    }

    /**
     * Updates a Provider
     * @param oldNode old provider
     * @param newNode updated provider
     */
    private void changeHandler(ChildData oldNode, ChildData newNode) {
        if (providerDataEmpty(newNode)) {
            return;
        }
        System.out.printf("Node changed, Old: [%s: %s] New: [%s: %s]%n", oldNode.getPath(), dataOf(oldNode),
                newNode.getPath(), dataOf(newNode));
        updateProvider(newNode);
    }

    /**
     * Adds or updates a Provider in the local cache
     * @param newNode updated provider
     */
    private void updateProvider(ChildData newNode) {
        JSONObject instance = JSON.parseObject(dataOf(newNode));
        System.out.println(instance);

        String url = "http://" + instance.getString("address") + ":" + instance.getString("port");
        ProviderInfo providerInfo = JSON.parseObject(instance.getString("payload"), ProviderInfo.class);
        providerInfo.setId(instance.getString("id"));
        providerInfo.setUrl(url);
        putProvider(instance.getString("name"), providerInfo);
    }

    /**
     * Caches a Provider, replacing any entry with the same ID.
     * CuratorCache reports every existing node as created when it starts, and a change event
     * carries the full new data, so appending blindly would cache the same Provider twice.
     */
    private void putProvider(String key, ProviderInfo providerInfo) {
        List<ProviderInfo> providerList = providersCache.computeIfAbsent(key, k -> new CopyOnWriteArrayList<>());
        providerList.removeIf(p -> p.getId().equals(providerInfo.getId()));
        providerList.add(providerInfo);
    }

    /**
     * Removes a Provider
     * @param oldNode provider
     */
    private void deleteHandler(ChildData oldNode) {
        if (providerDataEmpty(oldNode)) {
            return;
        }
        System.out.println("\n\n=================== delete provider ============================");
        System.out.printf("Node deleted, Old value: [%s: %s]%n", oldNode.getPath(), dataOf(oldNode));

        JSONObject instance = JSON.parseObject(dataOf(oldNode));
        System.out.println(instance);
        String provider = instance.getString("name");
        String id = instance.getString("id");
        List<ProviderInfo> providerList = providersCache.get(provider);
        if (providerList != null) {
            providerList.removeIf(p -> p.getId().equals(id));
        }
        System.out.println("=================== delete provider end ============================\n\n");
    }

    private static String dataOf(ChildData node) {
        return node.getData() == null ? "" : new String(node.getData(), StandardCharsets.UTF_8);
    }

    private boolean providerDataEmpty(ChildData node) {
        return node.getData() == null || node.getData().length == 0;
    }

    public synchronized void close() {
        try {
            resourcesCache.close();
            serviceDiscovery.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        client.close();
    }
}
It looks like a lot, but it is not very complicated; once the logic is clear you can write it yourself.
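`getProviders` first narrows the candidate list with `FilterLine.filter(providers, tags)` and only then hands it to the load balancer. `FilterLine` is not shown in this article, so the rule below is an assumption for illustration: with no requested tags every provider qualifies, otherwise a provider must carry all of the requested tags. `TagRouter` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical tag routing standing in for FilterLine.filter (not shown in the article).
// Assumed rule: no requested tags -> all providers; otherwise keep providers carrying every requested tag.
class TagRouter {

    static final class Provider {
        final String url;
        final List<String> tags;

        Provider(String url, List<String> tags) {
            this.url = url;
            this.tags = tags;
        }
    }

    static List<Provider> filter(List<Provider> providers, List<String> wanted) {
        if (wanted == null || wanted.isEmpty()) {
            return new ArrayList<>(providers);
        }
        List<Provider> result = new ArrayList<>();
        for (Provider p : providers) {
            if (p.tags != null && p.tags.containsAll(wanted)) {
                result.add(p);
            }
        }
        return result;
    }
}
```

An empty result is what makes `getProviders` return `null`, which the invocation handler then reports as "Can't find provider".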
Next is the change to the proxy request: in RpcInvocationHandler, the explicit url parameter is removed and the url is obtained from DiscoveryClient instead. Roughly:
@Slf4j
public class RpcInvocationHandler implements InvocationHandler, MethodInterceptor {

    // The fields group, version, tags and discoveryClient, and the invoke()/intercept() methods,
    // are omitted from this excerpt.

    /**
     * Sends the request to the server,
     * deserializes the result into an object and returns it
     * @param service service name
     * @param method service method
     * @param params method params
     * @return object
     */
    private Object process(Class<?> service, Method method, Object[] params) {
        log.info("Client proxy instance method invoke");

        // Custom RPC request structure RpcRequest: interface name, method name and arguments
        log.info("Build Rpc request");
        RpcRequest rpcRequest = new RpcRequest();
        rpcRequest.setServiceClass(service.getName());
        rpcRequest.setMethod(method.getName());
        rpcRequest.setArgv(params);
        rpcRequest.setGroup(group);
        rpcRequest.setVersion(version);

        // Get the request address of one Provider from DiscoveryClient
        String url = null;
        try {
            url = discoveryClient.getProviders(service.getName(), group, version, tags, method.getName());
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (url == null) {
            System.out.println("\nCan't find provider\n");
            return null;
        }

        // The client uses Netty to send the request to the server and get the result (custom structure: RpcResponse)
        log.info("Client send request to Server");
        RpcResponse rpcResponse;
        try {
            rpcResponse = RpcNettyClientSync.getInstance().getResponse(rpcRequest, url);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still see the interruption
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return null;
        } catch (URISyntaxException e) {
            e.printStackTrace();
            return null;
        }

        log.info("Client receive response Object");
        // An explicit check instead of assert: assertions are disabled by default at runtime
        if (rpcResponse == null) {
            log.info("Client receive empty response");
            return null;
        }
        if (!rpcResponse.getStatus()) {
            log.info("Client receive exception");
            rpcResponse.getException().printStackTrace();
            return null;
        }

        // Deserialize into an object and return it
        log.info("Response:: " + rpcResponse.getResult());
        return JSON.parse(rpcResponse.getResult().toString());
    }
}
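The key change is that the address is resolved inside the proxy on every call instead of being passed to `create`. The self-contained sketch below shows the same idea with a plain JDK dynamic proxy. `DiscoveryProxyDemo`, `Greeter` and `Transport` are hypothetical: `Transport` stands in for `RpcNettyClientSync`, and the `resolveUrl` function stands in for `DiscoveryClient.getProviders`. Unlike the handler above, which returns `null` when no provider is found, this sketch throws an `IllegalStateException`, so a missing provider cannot be mistaken for a `null` result.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.function.Function;

class DiscoveryProxyDemo {

    // Stand-in for the network call (RpcNettyClientSync in the article).
    interface Transport {
        Object send(String url, String service, String method, Object[] args);
    }

    // Example service interface used only for this demo.
    interface Greeter {
        String greet(String name);
    }

    static <T> T create(Class<T> service, Function<String, String> resolveUrl, Transport transport) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getDeclaringClass() == Object.class) {
                // equals/hashCode/toString are answered locally, not sent over the wire.
                switch (method.getName()) {
                    case "equals": return proxy == args[0];
                    case "hashCode": return System.identityHashCode(proxy);
                    default: return service.getSimpleName() + "Proxy";
                }
            }
            // Resolve the provider address at call time instead of taking it from the caller.
            String url = resolveUrl.apply(service.getName());
            if (url == null) {
                throw new IllegalStateException("No provider for " + service.getName());
            }
            return transport.send(url, service.getName(), method.getName(), args);
        };
        return service.cast(Proxy.newProxyInstance(service.getClassLoader(), new Class<?>[]{service}, handler));
    }
}
```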
The client code also drops the url and becomes simpler:
@Slf4j
public class ClientApplication {

    public static void main(String[] args) {
        // fastjson autoType whitelist: allow deserializing these model classes
        ParserConfig.getGlobalInstance().addAccept("com.rpc.demo.model.Order");
        ParserConfig.getGlobalInstance().addAccept("com.rpc.demo.model.User");

        RpcClient client = new RpcClient();
        RpcClient.setBalanceAlgorithmName(ConsistentHashBalance.NAME);

        UserService userService = client.create(UserService.class, "group2", "v2");
        User user = userService.findById(1);
        if (user == null) {
            log.info("Client service invoke Error");
        } else {
            System.out.println("\n\nuser1 :: find user id=1 from server: " + user.getName());
        }
    }
}
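The client above selects `ConsistentHashBalance`. Its code is not part of this article, so the following is a generic sketch of consistent hashing with virtual nodes, not the project's implementation: each provider URL is hashed onto a ring many times, and a call key goes to the first provider clockwise from its own hash. `ConsistentHashRing`, the use of the first 4 bytes of MD5 as the hash, and `service + "#" + method` as the key are all assumptions.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical consistent-hash ring with virtual nodes; not the project's ConsistentHashBalance.
class ConsistentHashRing {

    // Hash position on the ring -> provider URL
    private final TreeMap<Long, String> ring = new TreeMap<>();

    ConsistentHashRing(List<String> urls, int virtualNodes) {
        for (String url : urls) {
            for (int i = 0; i < virtualNodes; i++) {
                ring.put(hash(url + "#" + i), url);
            }
        }
    }

    // Returns the provider at or clockwise after the key's hash, wrapping to the start; null if empty.
    String select(String key) {
        if (ring.isEmpty()) {
            return null;
        }
        Map.Entry<Long, String> entry = ring.ceilingEntry(hash(key));
        return (entry != null ? entry : ring.firstEntry()).getValue();
    }

    // First four bytes of the MD5 digest as an unsigned 32-bit value.
    static long hash(String s) {
        try {
            byte[] d = MessageDigest.getInstance("MD5").digest(s.getBytes(StandardCharsets.UTF_8));
            return ((long) (d[0] & 0xFF) << 24) | ((d[1] & 0xFF) << 16) | ((d[2] & 0xFF) << 8) | (d[3] & 0xFF);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 not available", e);
        }
    }
}
```

The advantage of the ring over `hash(key) % n` is that removing a provider only moves the keys that were mapped to it; keys served by the other providers stay where they were.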
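Copyright notice: this is an original article by InfoQ author 萧.
Original article: http://xie.infoq.cn/article/a65040df65f7ac85a0ae3aff7
This article is published under the CC-BY 4.0 license; when reposting, please keep the original source and this copyright notice.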