用一致性 Hash 算法的实现负载均衡(Kotlin)

用户头像
Acker飏
关注
发布于: 2020 年 07 月 06 日

场景描述

假设公司有10台服务器,它们用来缓存Key-Value数据,目前有100万个这样的数据需要缓存,我们希望这些数据能够尽量均匀的缓存到这10台服务器上,以便分摊服务器压力。

常规的方式

常规的方式一般通过Hash取模运算,将请求发送到不同的对应服务器上,如下图:

这个算法非常简单,而且数据分散均匀,但这样的方式会存在一个问题,当增减机器的时候,服务器数量改变会导致,取模后的数值发生改变从而使得缓存数据失效,命中率大大降低。



一致性哈希算法:

原理:

  1. 将整个哈希值区间组织成一个圆环,其区间为,也就是一个32位的无符号整形数据组成的圆环。

  2. 计算每台服务器的Hash值(一般使用服务器名字,IP地址,端口号作为Key来进行计算),然后把计算得到的值映射到环中去,以3台服务器为例,如下图:



  1. 使用相同的Hash算法计算需要缓存的数据的Hash值,并把这些数据映射到环中去。

  2. 那么在这个环中,数据需要存储的服务器是沿着顺时针方向遇到的第一个服务器。



  1. 添加或减少节点的时候,相比与取模的方式会导致大量缓存失效,一致性Hash受影响的数据仅仅是新服务器在环上逆时针后面的第一个服务器之间的数据。不至于整个平台的缓存压力仅存在于少数服务器上。



由此可见使用一致性Hash算法,有效的避免了哈希取模的方式的弊端,在一定程度上避免了增减服务器带来的瞬间缓存压力,但它也存在一个问题,那就是Hash环偏斜,因为服务器的Key计算出来的Hash值映射在环上并不是均匀的,例如下面的情况:

带虚拟节点的一致性Hash

为了解决上述方式中Hash环偏斜的问题,我们可以引入虚拟节点(Virtual Node),这些节点是实际物理节点(Physical Node)在Hash环上的副本(Replica),一个物理节点可以对应多个副本,并作为虚拟节点映射在Hash环上,例如下图:

这样通过虚拟节点的添加,整个hash环上的节点就较为均匀了,如果蓝色的节点挂掉或者删除了,缓存压力会被顺时针转移到比邻的下一个虚拟节点上,而且这些下一个虚拟节点很可能分别是不同物理服务器的副本,所以这些因蓝色节点丢失导致的压力也不会集中某一台服务器上。



一致性Hash算法的具体实现

在实现之前我们首先要考虑的问题是数据结构与算法的选取,要将区间为内的所有整型数据分布在环上使用何种数据结构,时间复杂度才最低呢。



首先我们先分析下这个缓存的业务:

  1. 查找节点,计算出缓存的Hash之后,需要在这个数据结构中找到比这个值大的第一个Hash值对应的虚拟节点即可得到我们需要的真实物理节点,这是我们最常用的操作。

  2. 删除节点,需要在这个数据结构中查找找到对应的虚拟节点并删除数据。

  3. 增加节点,计算出节点的Hash值后,需要在这个数据结构中插入这些数据。



从上面的描述中我们可以得出,如果要提升效率我们要主要两个方面:

  1. 数据结构:需要满足有序数据的存储,查找的时间复杂度要足够的低。

对于这种要求来说二叉查找树是个不错的选择,但二叉查找树也会出现偏斜问题,出现不平衡的情况,所以要选择一种自平衡二叉查找树,这里我们采用红黑树来实现,而且JDK中的TreeMap就是以红黑树的方式实现的。对红黑树不了解的同学可以看看这篇漫画算法:5分钟搞明白红黑树到底是什么?

但红黑树也有一个缺点,就是插入数据的代价比较高,因为这是要维护红黑树的自平衡带来的消耗。



  1. 计算Hash值的算法:尽可能的散列,碰撞度低,效率高。

基于这个要求,String自带的hashCode()方法就无法满足这个条件。

fun main() {
println("node:192.168.1.1:123".hashCode().absoluteValue)
println("node:192.168.1.2:123".hashCode().absoluteValue)
println("node:192.168.1.3:123".hashCode().absoluteValue)
println("node:192.168.1.4:123".hashCode().absoluteValue)
println("node:192.168.1.5:123".hashCode().absoluteValue)
}

运行结果:

1919693013
1920616534
1921540055
1922463576
1923387097

Process finished with exit code 0

可见散列程度过低,会使得Hash环严重偏斜,所以我们需要重写Hash算法。



最后我选了Google的CityHash算法,有较高的效率以及散列程度,而且这个算法不适用于加密,但适合用在散列表,正适合于当前的情况。

这是使用CityHash之后的结果:

2122578006
1764325874
1330746139
492796170
1581359175

Process finished with exit code 0



interface Hashable {
fun hash32(raw: String): Int
}
/**
* 迁移cityhash的官方c版本,
*/
class CityHash: Hashable {
override fun hash32(raw: String): Int {
val byteArray = convertString2UTF8(raw)
val len = byteArray.size
if (len <= 24) {
return if (len <= 12) if (len <= 4) hash32Len0to4(byteArray) else hash32Len5to12(byteArray) else hash32Len13to24(
byteArray
)
}
// len > 24
var h = len
var g = c1 * len
var f = g
var a0 =
rotate32(
fetch32(
byteArray,
len - 4
) * c1, 17
) * c2
var a1 =
rotate32(
fetch32(
byteArray,
len - 8
) * c1, 17
) * c2
var a2 = rotate32(
fetch32(byteArray, len - 16) * c1,
17
) * c2
var a3 = rotate32(
fetch32(byteArray, len - 12) * c1,
17
) * c2
var a4 = rotate32(
fetch32(byteArray, len - 20) * c1,
17
) * c2
h = h xor a0
h = rotate32(h, 19)
h = h * 5 + -0x19ab949c
h = h xor a2
h = rotate32(h, 19)
h = h * 5 + -0x19ab949c
g = g xor a1
g = rotate32(g, 19)
g = g * 5 + -0x19ab949c
g = g xor a3
g = rotate32(g, 19)
g = g * 5 + -0x19ab949c
f += a4
f = rotate32(f, 19)
f = f * 5 + -0x19ab949c
var iters = (len - 1) / 20
var pos = 0
do {
a0 =
rotate32(
fetch32(
byteArray,
pos
) * c1, 17
) * c2
a1 = fetch32(byteArray, pos + 4)
a2 = rotate32(
fetch32(byteArray, pos + 8) * c1,
17
) * c2
a3 = rotate32(
fetch32(byteArray, pos + 12) * c1,
17
) * c2
a4 = fetch32(byteArray, pos + 16)
h = h xor a0
h = rotate32(h, 18)
h = h * 5 + -0x19ab949c
f += a1
f = rotate32(f, 19)
f = f * c1
g += a2
g = rotate32(g, 18)
g = g * 5 + -0x19ab949c
h = h xor a3 + a1
h = rotate32(h, 19)
h = h * 5 + -0x19ab949c
g = g xor a4
g = Integer.reverseBytes(g) * 5
h += a4 * 5
h = Integer.reverseBytes(h)
f += a0
val swapValue = f
f = g
g = h
h = swapValue
pos += 20
} while (--iters != 0)
g = rotate32(
g,
11
) * c1
g = rotate32(
g,
17
) * c1
f = rotate32(
f,
11
) * c1
f = rotate32(
f,
17
) * c1
h = rotate32(h + g, 19)
h = h * 5 + -0x19ab949c
h = rotate32(
h,
17
) * c1
h = rotate32(h + f, 19)
h = h * 5 + -0x19ab949c
h = rotate32(
h,
17
) * c1
return h
}
private fun hash32Len0to4(byteArray: ByteArray): Int {
var b = 0
var c = 9
val len = byteArray.size
for (i in 0 until len) {
val v = byteArray[i].toInt()
b = b * c1 + v
c = c xor b
}
return fmix(mur(b, mur(len, c)))
}
private fun hash32Len5to12(byteArray: ByteArray): Int {
val len = byteArray.size
var a = len
var b = len * 5
var c = 9
val d = b
a += fetch32(byteArray, 0)
b += fetch32(byteArray, len - 4)
c += fetch32(byteArray, len ushr 1 and 4)
return fmix(mur(c, mur(b, mur(a, d))))
}
private fun hash32Len13to24(byteArray: ByteArray): Int {
val len = byteArray.size
val a = fetch32(byteArray, (len ushr 1) - 4)
val b = fetch32(byteArray, 4)
val c = fetch32(byteArray, len - 8)
val d = fetch32(byteArray, len ushr 1)
val e = fetch32(byteArray, 0)
val f = fetch32(byteArray, len - 4)
return fmix(mur(f, mur(e, mur(d, mur(c, mur(b, mur(a, len)))))))
}
private fun loadUnaligned32(byteArray: ByteArray, start: Int): Int {
var result = 0
val orderIter = OrderIter(
4,
IS_BIG_EDIAN
)
while (orderIter.hasNext()) {
val next = orderIter.next()
val value: Int = byteArray[next + start].toInt() and 0xff shl next * 8
result = result or value
}
return result
}
private fun fetch32(byteArray: ByteArray, start: Int): Int {
return loadUnaligned32(byteArray, start)
}
private fun fmix(h: Int): Int {
var h = h
h = h xor (h ushr 16)
h *= -0x7a143595
h = h xor (h ushr 13)
h *= -0x3d4d51cb
h = h xor (h ushr 16)
return h
}
private fun mur(a: Int, h: Int): Int {
// Helper from Murmur3 for combining two 32-bit values.
var a = a
var h = h
a *= c1
a = rotate32(a, 17)
a *= c2
h = h xor a
h = rotate32(h, 19)
return h * 5 + -0x19ab949c
}
private class OrderIter internal constructor(private val size: Int, private val isBigEdian: Boolean) {
private var index = 0
operator fun hasNext(): Boolean {
return index < size
}
operator fun next(): Int {
return if (!isBigEdian) {
index++
} else {
size - 1 - index++
}
}
}
private fun convertString2UTF8(raw: String?): ByteArray {
return try {
raw!!.toByteArray(charset("UTF-8"))
} catch (e: UnsupportedEncodingException) {
// don't happen
throw RuntimeException(e)
}
}
companion object {
private val IS_BIG_EDIAN = "little" != System.getProperty("sun.cpu.endian")
// Magic numbers for 32-bit hashing. Copied from Murmur3.
const val c1 = -0x3361d2af
const val c2 = 0x1b873593
fun rotate32(`val`: Int, shift: Int): Int {
// Avoid shifting by 32: doing so yields an undefined result.
return if (shift == 0) `val` else `val` ushr shift or (`val` shl 32 - shift)
}
}
}



一致性Hash算法的Kotlin实现

/**
* 物理节点
*/
interface PhysicalNode {
fun hashKey(): String
fun put(key: String, value: String): String?
fun get(key: String): String?
fun size(): Int
fun clear()
fun remove(key: String): String?
}
/**
* 虚拟节点
*
* @param physicalNode 物理节点
* @param replica 虚拟节点id
*/
data class VirtualNode<out T : PhysicalNode>(val physicalNode: T, private val replica: Int) : PhysicalNode {
override fun hashKey() = "${physicalNode.hashKey()}#$replica"
override fun put(key: String, value: String): String? = null
override fun get(key: String): String? = null
override fun size(): Int = 0
override fun clear() {}
override fun remove(key: String): String? = null
}
/**
* 常规的服务节点
*
* @param name 服务名称
* @param host 服务host
* @param port 服务port
*/
data class NormalPhysicalNode(val name: String, val host: String, val port: Int) : PhysicalNode {
private val caches = hashMapOf<String,String>()
override fun hashKey() = "$name:$host:$port"
override fun put(key: String, value: String) = caches.put(key, value)
override fun get(key: String) = caches[key]
override fun size() = caches.size
override fun clear() = caches.clear()
override fun remove(key: String) = caches.remove(key)
}



/**
* 一致性哈希算法实现
*/
class ConsistentHash<T : PhysicalNode>(private var hashable: Hashable) {
private var ring = TreeMap<Int, VirtualNode<T>>()
private var physicalNodes = mutableListOf<T>()
/**
* 添加物理节点,如果节点存在则为其添加副本
* @param node 物理节点
* @param replicas 副本节点个数,默认为1
*/
fun add(node: T, replicas: Int = 1) {
// 添加物理
physicalNodes.add(node)
// 获得该物理节点已有的副本数量
val existingReplicas = ring.values
.filter {
it.physicalNode.hashKey() == node.hashKey()
}.size
// 根据副本数量增加虚拟节点
(0..(if (replicas < 1) 0 else replicas - 1)).forEach {
val vNode = VirtualNode(node, existingReplicas + it)
ring[vNode.hashKey().hash32(hashable)] = vNode
}
}
/**
* 移除物理节点下所有虚拟节点
* @param key 节点Key
*/
fun remove(key: String) {
physicalNodes.removeIf { it.hashKey() == key }
ring.keys
.filter { ring[it]?.physicalNode?.hashKey() == key }
.forEach { ring.remove(it) }
}
/**
* 获得所有虚拟节点
*/
fun getAllVirtualNode() = ring.values
/**
* 获得物理节点数量
*/
fun getPhysicalNodeNum() = physicalNodes.size
/**
* 获得物理节点缓存情况
*/
fun getLoadInfo() = buildList {
physicalNodes.forEach {
// println("$it - size: ${it.size()}")
add(it.size().toDouble()) //获得每个服务器的缓存数量
}
}
/**
* 获得负载情况
*/
fun getLoadStdDeviation() = getLoadInfo().stdDeviation()
/**
* 根据key获取物理节点
*/
fun getNode(key: String): T? {
if (ring.isEmpty()) return null
// 获取Key的Hash值并以此获得大于该Hash值的所有Map
val tail = ring.tailMap(key.hash32(hashable))
// 第一个Key就是顺时针过去离node最近的那个节点,如果是空集则取环上第一个节点
return ring[if (tail.isEmpty()) ring.firstKey() else tail.firstKey()]?.physicalNode
}
}

这里我采用DSL的方式构建一致性HASH:

/**
* 一致性哈希构建函数(DSL)
*/
inline fun <T: PhysicalNode> buildConsistentHash(buildAction : ConsistentHashBuilder<T>.() -> Unit): ConsistentHash<T> {
val builder = ConsistentHashBuilder<T>()
builder.buildAction()
val consistentHash = ConsistentHash<T>(builder.hashable)
builder.nodes.forEach {
consistentHash.add(it, builder.replicas)
}
return consistentHash
}
/**
* 一致性哈希构建类
*/
class ConsistentHashBuilder<T : PhysicalNode>(
var hashable: Hashable = CityHash(),
var nodes : Collection<T> = mutableListOf(),
var replicas: Int = 1) {
/**
* 初始化节点
*/
fun nodes(replicas: Int = 1, builderAction: MutableList<T>.() -> Unit): ConsistentHashBuilder<T> {
this.nodes = ArrayList<T>().apply(builderAction)
this.replicas = replicas
return this
}
}

例子, 两种方式:

fun main() {
val consistentHash =
buildConsistentHash<NormalPhysicalNode> {
//构建节点,指定副本数量
nodes(100) {
//添加3个物理节点
add(NormalPhysicalNode("Node1", "192.169.1.1", 8080))
add(NormalPhysicalNode("Node2", "192.169.1.2", 8080))
add(NormalPhysicalNode("Node3", "192.169.1.3", 8080))
}
}
val consistentHash1 =
buildConsistentHash<NormalPhysicalNode> {
// 指定Hash算法
hashable = CityHash()
// 指定副本数
replicas = 100
// 添加物理节点
nodes = buildList {
//添加3个物理节点
add(NormalPhysicalNode("Node1", "192.169.1.1", 8080))
add(NormalPhysicalNode("Node2", "192.169.1.2", 8080))
add(NormalPhysicalNode("Node3", "192.169.1.3", 8080))
}
}
}



一些工具扩展函数:

/**
* 字符串获取Hash值扩展函数
*/
fun String.hash32(hashable: Hashable) = hashable.hash32(this).absoluteValue
/**
* List计算标准差扩展函数
*/
fun List<Double>.stdDeviation(): Double {
val m = this.size
// 获得平均值
val dAve = this.average()
var dVar = 0.0
//求方差
(0 until m).forEach {
dVar += (this[it] - dAve) * (this[it] - dAve)
}
//return sqrt(dVar/(m-1)); //样本标准差
return sqrt(dVar / m)
}
/**
* 获取指定范围及长度的随机字符串
*/
fun ClosedRange<Char>.randomString(length: Int) =
(1..length)
.map { (Random().nextInt(endInclusive.toInt() - start.toInt()) + start.toInt()).toChar() }
.joinToString("")

一致性Hash算法测试

class ConsistentHashTest {
@Test
fun `with virtual node`() {
// 将数据写入CSV方便统计分析
csvWriter().open("test.csv") {
writeRow("物理节点数", "副本数量", "虚拟节点数", "耗时", "标准差")
}
// 准备100万KV数据
val kvList = mutableListOf<Pair<String, String>>()
repeat(1000000) {
val kv = ('A' .. 'z').randomString(10) to "v1"
kvList.add(kv)
}
// 测试单个服务器的副本数从1到350的性能
repeat(350){ count ->
// 构建一致性哈希
val consistentHash =
buildConsistentHash<NormalPhysicalNode> {
//构建节点,指定副本数量
nodes(replicas = count + 1) {
//添加10个物理节点
(1 .. 10).forEach {
add(NormalPhysicalNode("Node$it", "192.169.1.$it", 8080))
}
}
}
// 计算物理节点查询总耗时
val s = System.currentTimeMillis()
kvList.forEach { kv ->
// 查询物理节点并进行缓存
consistentHash.getNode(kv.first)?.put(kv.first,kv.second)
}
val duration = System.currentTimeMillis() - s
val physicalNodeNum = consistentHash.getPhysicalNodeNum()
val replica = count + 1
val virtualNodeNum = consistentHash.getAllVirtualNode().size
val cacheSize = kvList.size
val loadStdDeviation = consistentHash.getLoadStdDeviation()
println("物理节点数:$physicalNodeNum 副本倍数:$replica 虚拟节点数:$virtualNodeNum" +
" 缓存数:$cacheSize 耗时:${duration}ms 标准差:$loadStdDeviation")
csvWriter().open("test.csv", append = true) {
// ("物理节点数", "副本数量", "虚拟节点数", "耗时", "标准差")
writeRow(physicalNodeNum, replica, virtualNodeNum, duration, loadStdDeviation)
}
}
}
}
运行结果:

以10个服务器节点,每个物理服务器从1到350个副本,存储100万数据的条件下,将测试运行导出数据后进行统计分析得到:

  1. 随着副本数的增加,表达负责均衡度的标准差性急剧减小,从多项式趋势线可以看出,到50倍副本时方差趋于稳定,在200-250倍之间获得方差获得最低值,标准差大致分布在4000左右。

  2. 随着副本的增多耗时稍有增加,但不明显,到200倍后趋于稳定,在700ms左右。



从以上分析得出,在当前算法下,10台物理服务器,采用200倍左右副本为最佳实践



扩展

为了对比CityHash,我又使用了一下一些常用的Hash算法进行对比:

Java自带Hash, APHash, DEKHash, DJBHash, SDBMHash, BKDRHash, ELFHash, PJWHash, JSHash, RSHash, FNVHash, BernsteinHash, OneByOneHash

最终标准差能低于10000的只有CityHash,FNVHash, OneByOneHash, 其中CityHash的标准差最低能达到3500左右,其他两种只能到5000左右,目前看来CityHash还是最优选择。



发布于: 2020 年 07 月 06 日 阅读数: 227
用户头像

Acker飏

关注

还未添加个人签名 2018.05.03 加入

还未添加个人简介

评论

发布
暂无评论
用一致性Hash算法的实现负载均衡(Kotlin)