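A Go Implementation of Consistent Hashing
What is a hash
A hash is a way of mapping content from an unbounded domain onto a bounded range. That sounds abstract, so here are two examples: mapping an arbitrarily long string to an int32, or mapping a byte array of any size to a 32-byte array. After the mapping, the bounded value can stand in for the unbounded content. The bounded value is usually short, such as a 32-bit integer or a 32-byte string. This short identifier makes equality checks fast: if two hashes differ, the contents differ, and only when the hashes match does the full content need to be compared.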
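As a minimal sketch of the idea above, the following maps strings of very different lengths to a 32-bit value with FNV-1 from Go's standard library. The helper name `hash32` is an assumption made for illustration and is not part of the article's code.

```go
package main

import (
	"fmt"
	"hash/fnv"
	"strings"
)

// hash32 (hypothetical helper) maps a string of any length to a uint32
// using the FNV-1 hash from the standard library.
func hash32(s string) uint32 {
	h := fnv.New32()
	h.Write([]byte(s)) // Write on a hash.Hash never returns an error
	return h.Sum32()
}

func main() {
	short := "a"
	long := strings.Repeat("consistent hashing ", 1000)

	// Both inputs end up as a value that fits in 32 bits.
	fmt.Println(len(short), hash32(short))
	fmt.Println(len(long), hash32(long))

	// Equal content always produces equal hashes.
	fmt.Println(hash32("hello") == hash32("hello"))
}
```

A fresh hasher is created on every call so that the result depends only on the input string.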
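Why can a short value identify a long one?
Take a string of 100 English letters: there are 26^100 possible contents, an astronomical number. In real applications, however, 100 characters usually hold a sentence or a piece of JSON data, so far fewer values actually occur, and it becomes feasible to represent them with a 32-bit integer.
What is a hash function
A function that maps a larger piece of content to a smaller one is a hash function. For example, Object.hashCode() in Java maps an object to a 32-bit integer.
Because the output range is smaller than the input domain, two different inputs may be mapped to the same integer. This is called a collision.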
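Collisions are easy to provoke once the output range is small. The sketch below reduces FNV-1 hashes to 16 buckets and feeds in 17 distinct keys, so by the pigeonhole principle at least two keys must share a bucket. The names `bucket` and `firstCollision` are hypothetical and chosen only for this illustration.

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// bucket reduces the FNV-1 hash of key to one of n buckets.
func bucket(key string, n uint32) uint32 {
	h := fnv.New32()
	h.Write([]byte(key))
	return h.Sum32() % n
}

// firstCollision returns the first pair of distinct keys that land in the
// same bucket, or found == false if every key has a bucket of its own.
func firstCollision(keys []string, n uint32) (first, second string, found bool) {
	seen := make(map[uint32]string)
	for _, k := range keys {
		b := bucket(k, n)
		if prev, ok := seen[b]; ok {
			return prev, k, true
		}
		seen[b] = k
	}
	return "", "", false
}

func main() {
	// 17 distinct keys, 16 buckets: a collision is unavoidable.
	keys := make([]string, 17)
	for i := range keys {
		keys[i] = fmt.Sprintf("key-%d", i)
	}
	a, b, ok := firstCollision(keys, 16)
	fmt.Println(a, b, ok)
}
```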
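Where hashing is used
A HashMap maps each object to an int computed from its content.
A distributed memcache deployment maps each key to one of several server nodes.
Database table sharding maps each user id to one of several tables.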
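The table sharding case can be sketched in a few lines. This is only an illustration under assumptions: the function name `tableFor`, the `user_` table prefix, and the table count of 16 are all made up here and do not come from the article.

```go
package main

import "fmt"

// tableFor (hypothetical) picks a shard table for a user id by taking the
// id modulo the number of tables. The "user_" prefix is an assumption.
func tableFor(userID uint64, tables uint64) string {
	return fmt.Sprintf("user_%02d", userID%tables)
}

func main() {
	for _, id := range []uint64{1, 16, 17, 12345} {
		fmt.Println(id, "->", tableFor(id, 16))
	}
}
```

Numeric ids can be used directly as the hash value. For string keys, a hash function would be applied first and its result reduced modulo the table count.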
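What problem does consistent hashing solve
As an example, suppose cache keys are mapped to different nodes. With a good hash function and a large number of keys, each node ends up storing roughly the same number of keys.
The problem: if the number of server nodes changes, the mapping between keys and nodes changes with it. With a simple scheme such as hash(key) mod N, almost every key has to be remapped, which is far too expensive. Consistent hashing was introduced to reduce the number of keys that must be remapped when the set of server nodes changes.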
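To make the remapping cost concrete, here is a small sketch that counts how many keys change owner when a modulo-based scheme grows from 10 to 11 nodes. The function `countMoved` and the key format are assumptions made for illustration. Roughly 10 out of every 11 keys would be expected to move for uniformly distributed hashes, since a key stays put only when its hash has the same remainder under both node counts.

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// countMoved reports how many hash values are assigned to a different node
// index when the node count changes from oldN to newN under "hash mod N".
func countMoved(hashes []uint32, oldN, newN uint32) int {
	moved := 0
	for _, h := range hashes {
		if h%oldN != h%newN {
			moved++
		}
	}
	return moved
}

func main() {
	hashes := make([]uint32, 0, 100000)
	for i := 0; i < 100000; i++ {
		h := fnv.New32()
		h.Write([]byte(fmt.Sprintf("key:%d", i)))
		hashes = append(hashes, h.Sum32())
	}
	moved := countMoved(hashes, 10, 11)
	fmt.Printf("%d of %d keys moved when going from 10 to 11 nodes\n", moved, len(hashes))
}
```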
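The basic mapping principle of consistent hashing and its key concepts
- The ring has 2^32 points, numbered 0 to 2^32-1.
- Virtual nodes are placed on the ring, and each virtual node corresponds to one real node. Usually N virtual nodes correspond to the same real node, where N is in the hundreds.
- Any key can be mapped by the hash function to a uint32 integer p, which corresponds to a point on the ring. Moving clockwise from p finds the nearest virtual node (wrapping around to the first virtual node if p lies past the last one), and from it the real node.
Key concepts: ring, virtual node, real node, hash function.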
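The concepts above can be sketched with a sorted slice and a binary search. This is a simplified illustration rather than the implementation that follows: the type and function names (`point`, `ring`, `newRing`, `lookupHash`, `lookup`, `hashOf`) and the node names are assumptions chosen for this example.

```go
package main

import (
	"fmt"
	"hash/fnv"
	"sort"
)

// point is one virtual node: a position on the ring plus its real node.
type point struct {
	hash uint32
	node string
}

// ring keeps its virtual nodes sorted by hash.
type ring struct {
	points []point
}

func hashOf(s string) uint32 {
	h := fnv.New32()
	h.Write([]byte(s))
	return h.Sum32()
}

// newRing places `replicas` virtual nodes on the ring for every real node.
func newRing(replicas int, nodes ...string) *ring {
	r := &ring{}
	for _, n := range nodes {
		for i := 0; i < replicas; i++ {
			r.points = append(r.points, point{hashOf(fmt.Sprintf("%s.%d", n, i)), n})
		}
	}
	sort.Slice(r.points, func(i, j int) bool { return r.points[i].hash < r.points[j].hash })
	return r
}

// lookupHash walks clockwise from h to the first virtual node at or after
// h, wrapping around to the start of the ring when h is past the last one.
func (r *ring) lookupHash(h uint32) string {
	if len(r.points) == 0 {
		return ""
	}
	i := sort.Search(len(r.points), func(i int) bool { return r.points[i].hash >= h })
	if i == len(r.points) {
		i = 0
	}
	return r.points[i].node
}

func (r *ring) lookup(key string) string { return r.lookupHash(hashOf(key)) }

func main() {
	before := newRing(100, "node-a", "node-b", "node-c")
	after := newRing(100, "node-a", "node-b", "node-c", "node-d")

	const total = 10000
	moved := 0
	for i := 0; i < total; i++ {
		key := fmt.Sprintf("key:%d", i)
		if before.lookup(key) != after.lookup(key) {
			moved++
		}
	}
	fmt.Printf("%d of %d keys changed owner after adding node-d\n", moved, total)
}
```

Adding a node only inserts new points on the ring, so the keys that change owner are the ones that now reach one of the new node's points first. All other keys keep their previous node.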
Implementation (Go)
package consistency_hash

import (
	"errors"
	"fmt"
	"hash/fnv"
	"math"

	"github.com/emirpasic/gods/maps/treemap"
)

// NodeExistError is returned when a node with the same name was already added.
type NodeExistError string

func (n *NodeExistError) Error() string {
	return string(*n)
}

type ConsistencyHash interface {
	GetNode(string) *Node
	AddNode(*Node) error
}

// Node is a real node.
type Node struct {
	Name string
}

// VirtualNode is a point on the ring. It owns the hash range
// between[0]..between[1] (inclusive), where between[1] == hash.
type VirtualNode struct {
	hash    uint32
	name    string
	node    *Node
	between [2]uint32
}

func (v *VirtualNode) String() string {
	return fmt.Sprintf("%s: %d - %v", v.name, v.hash, v.between)
}

// conHashImpl is a consistent hash backed by a sorted slice of virtual nodes.
type conHashImpl struct {
	vNodesPerNode int
	ring          Ring
	nodes         map[string]*Node
}

// Ring holds the virtual nodes sorted by hash.
type Ring struct {
	VirtualNodes []*VirtualNode
}

// ceiling binary-searches for the virtual node whose range contains hash.
// It returns the index, or -1 if no range contains it.
func (r *Ring) ceiling(hash uint32) int32 {
	if len(r.VirtualNodes) == 0 {
		return -1
	}
	s, e := 0, len(r.VirtualNodes)-1
	pos := (s + e) / 2
	for pos >= 0 && s < e {
		if r.VirtualNodes[pos].between[0] <= hash && r.VirtualNodes[pos].between[1] >= hash {
			break
		}
		if r.VirtualNodes[pos].between[0] > hash {
			e = pos - 1
		} else {
			s = pos + 1
		}
		pos = (s + e) / 2
	}
	if hash > r.VirtualNodes[pos].between[1] || hash < r.VirtualNodes[pos].between[0] {
		return -1
	}
	return int32(pos)
}

// addNode inserts a virtual node into the ring, splitting the range of the
// virtual node that currently covers its hash.
func (r *Ring) addNode(vnode VirtualNode) error {
	if len(r.VirtualNodes) >= math.MaxInt32-1 {
		return fmt.Errorf("too many nodes: %d", len(r.VirtualNodes))
	}
	targetPos := r.ceiling(vnode.hash)
	if targetPos < 0 {
		// No existing range covers the hash: append after the last node.
		begin := uint32(0)
		if len(r.VirtualNodes) > 0 {
			begin = r.VirtualNodes[len(r.VirtualNodes)-1].between[1] + 1
		}
		r.VirtualNodes = append(r.VirtualNodes, &VirtualNode{
			name:    vnode.name,
			node:    vnode.node,
			hash:    vnode.hash,
			between: [2]uint32{begin, vnode.hash},
		})
	} else {
		targetNode := r.VirtualNodes[targetPos]
		var additionNodes []*VirtualNode
		if vnode.hash == targetNode.between[1] {
			return errors.New("same hash")
		}
		if vnode.hash < targetNode.hash {
			additionNodes = []*VirtualNode{{
				name:    vnode.name,
				node:    vnode.node,
				hash:    vnode.hash,
				between: [2]uint32{targetNode.between[0], vnode.hash},
			}, {
				name:    targetNode.name,
				node:    targetNode.node,
				hash:    targetNode.hash,
				between: [2]uint32{vnode.hash + 1, targetNode.between[1]},
			}}
		} else {
			additionNodes = []*VirtualNode{targetNode, {
				name:    vnode.name,
				node:    vnode.node,
				hash:    vnode.hash,
				between: [2]uint32{targetNode.between[1] + 1, vnode.hash},
			}}
		}
		// Replace the target node with the two nodes in additionNodes.
		if targetPos < int32(len(r.VirtualNodes)-1) {
			// new list is append([0:targetPos], additionNodes, [targetPos+1:])
			r.VirtualNodes = append(r.VirtualNodes[:targetPos], append(additionNodes, r.VirtualNodes[targetPos+1:]...)...)
		} else {
			r.VirtualNodes = append(r.VirtualNodes[:targetPos], additionNodes...)
		}
	}
	return nil
}

// nextNodeFrom returns the virtual node that owns hash, wrapping around to
// the first virtual node when hash lies past the last one. It returns nil
// for an empty ring.
func (r *Ring) nextNodeFrom(hash uint32) *VirtualNode {
	if len(r.VirtualNodes) == 0 {
		return nil
	}
	pos := r.ceiling(hash)
	if pos < 0 {
		return r.VirtualNodes[0]
	}
	return r.VirtualNodes[pos]
}

// hashCode returns the FNV-1 hash of a key. A hasher keeps accumulating the
// bytes written to it until Reset is called, so a fresh hasher is created
// per call; the result then depends only on key and the function is safe to
// call from several goroutines.
func hashCode(key string) uint32 {
	h := fnv.New32()
	h.Write([]byte(fmt.Sprintf("%s%s---", key, key)))
	return h.Sum32()
}

// AddNode adds a real node by placing vNodesPerNode virtual nodes on the ring.
func (c *conHashImpl) AddNode(node *Node) error {
	if _, existed := c.nodes[node.Name]; existed {
		err := NodeExistError(fmt.Sprintf("node name(%s) existed", node.Name))
		return &err
	}
	if c.nodes == nil {
		c.nodes = make(map[string]*Node)
	}
	c.nodes[node.Name] = node
	for i := 0; i < c.vNodesPerNode; i++ {
		vName := fmt.Sprintf("%s.%d", node.Name, i)
		vNode := VirtualNode{hash: hashCode(vName), name: vName, node: node}
		err := c.ring.addNode(vNode)
		if err != nil {
			fmt.Println("error to add node ", err)
		}
	}
	return nil
}

// GetNode returns the real node responsible for key.
func (c *conHashImpl) GetNode(key string) *Node {
	vNode := c.ring.nextNodeFrom(hashCode(key))
	if vNode == nil {
		return nil
	}
	return vNode.node
}

// NewConsistencyHash creates a slice-based consistent hash with the given nodes.
func NewConsistencyHash(factor int, nodes ...*Node) (ConsistencyHash, error) {
	impl := &conHashImpl{vNodesPerNode: factor, nodes: make(map[string]*Node)}
	for _, node := range nodes {
		err := impl.AddNode(node)
		if err != nil {
			fmt.Println("error add node ", err)
		}
	}
	return impl, nil
}

// NewConsistencyHash2 creates a consistent hash backed by a treemap keyed by
// virtual node hash.
func NewConsistencyHash2(factor int, nodes ...*Node) (ConsistencyHash, error) {
	impl := &conHashImpl2{
		vNodesPerNode: factor,
		nodes: treemap.NewWith(func(a, b interface{}) int {
			aa, _ := a.(uint32)
			bb, _ := b.(uint32)
			if aa > bb {
				return 1
			} else if aa == bb {
				return 0
			} else {
				return -1
			}
		}),
	}
	for _, node := range nodes {
		err := impl.AddNode(node)
		if err != nil {
			fmt.Println("error add node ", err)
		}
	}
	return impl, nil
}

type conHashImpl2 struct {
	vNodesPerNode int
	nodes         *treemap.Map
}

func (c *conHashImpl2) GetNode(s string) *Node {
	// Ceiling finds the smallest key >= the hash; wrap to Min if there is none.
	_, v := c.nodes.Ceiling(hashCode(s))
	if v == nil {
		_, v = c.nodes.Min()
	}
	if n, yes := v.(*VirtualNode); yes {
		return n.node
	}
	fmt.Println("  error: ", s, v)
	return nil
}

func (c *conHashImpl2) AddNode(node *Node) error {
	for i := 0; i < c.vNodesPerNode; i++ {
		name := fmt.Sprintf("%s.%d", node.Name, i)
		hash := hashCode(name)
		c.nodes.Put(hash, &VirtualNode{name: name, hash: hash, node: node})
	}
	return nil
}
Test code (Go)
package consistency_hash

import (
	"fmt"
	"math"
	"math/rand"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestCeiling(t *testing.T) {
	assert := require.New(t)
	r := &Ring{}
	r.VirtualNodes = []*VirtualNode{
		{hash: 10, between: [2]uint32{0, 10}},
		{hash: 100, between: [2]uint32{11, 100}},
		{hash: 1000, between: [2]uint32{101, 1000}},
	}
	pos := r.ceiling(5)
	assert.Equal(int32(0), pos)
	pos = r.ceiling(14)
	assert.Equal(int32(1), pos)
	pos = r.ceiling(100)
	assert.Equal(int32(1), pos)
	pos = r.ceiling(1000)
	assert.Equal(int32(2), pos)
	pos = r.ceiling(1100)
	assert.Equal(int32(-1), pos)
}

func TestCeiling2(t *testing.T) {
	assert := require.New(t)
	r := &Ring{}
	r.VirtualNodes = []*VirtualNode{
		{hash: 10, between: [2]uint32{0, 10}},
		{hash: 1000, between: [2]uint32{11, 10000}},
	}
	pos := r.ceiling(5)
	assert.Equal(int32(0), pos)
	pos = r.ceiling(14)
	assert.Equal(int32(1), pos)
}

func TestAddNode(t *testing.T) {
	assert := require.New(t)
	r := &Ring{}
	r.addNode(VirtualNode{hash: 10})
	r.addNode(VirtualNode{hash: 10000})
	r.addNode(VirtualNode{hash: 100})
	r.addNode(VirtualNode{hash: 100000})
	r.addNode(VirtualNode{hash: 1000})
	r.addNode(VirtualNode{hash: 1000}) // duplicate hash, rejected with "same hash"
	fmt.Println("=====", r.VirtualNodes)
	pos := r.ceiling(5)
	assert.Equal(int32(0), pos)
	pos = r.ceiling(14)
	assert.Equal(int32(1), pos)
	pos = r.ceiling(100)
	assert.Equal(int32(1), pos)
	pos = r.ceiling(1000)
	assert.Equal(int32(2), pos)
	pos = r.ceiling(1100)
	assert.Equal(int32(3), pos)
	pos = r.ceiling(1000001)
	assert.Equal(int32(-1), pos)
}

func TestHashcode(t *testing.T) {
	assert := require.New(t)
	code1 := hashCode("node1.1")
	code2 := hashCode("node1.2")
	fmt.Println(code1, code2)
	assert.NotEqual(code1, code2)
	// Spread one million hashes over 10 buckets and report how even they are.
	values := [10]int{}
	for j := 0; j < 1000000; j++ {
		code := hashCode(fmt.Sprintf("node0.%d", j))
		idx := code % uint32(len(values))
		values[idx] = values[idx] + 1
	}
	MakeIterator(func(it chan uint32) {
		for _, v := range values {
			it <- uint32(v)
		}
		close(it)
	}).calcSD()
}

func TestNewConsistencyHash(t *testing.T) {
	assert := require.New(t)
	factor := 35
	cHash, err := NewConsistencyHash(factor, &Node{"node1"})
	assert.NotNil(cHash)
	assert.NoError(err)
	impl := cHash.(*conHashImpl)
	assert.Equal(factor, len(impl.ring.VirtualNodes))
	for _, key := range []string{"a", "b"} {
		node := cHash.GetNode(key)
		assert.NotNil(node)
		fmt.Println(key, "=> ", node)
	}
	err = cHash.AddNode(&Node{"node2"})
	assert.NoError(err)
	assert.Equal(factor*2, len(impl.ring.VirtualNodes))
	for _, key := range []string{"a", "b", "c", "d"} {
		node := cHash.GetNode(key)
		assert.NotNil(node)
		fmt.Println(key, "=> ", node)
	}
}

type newHashFunc func(factor int, nodes ...*Node) (ConsistencyHash, error)

// buildHash creates a consistent hash with the given number of virtual nodes
// per real node and adds 10 real nodes named node.0 to node.9.
func buildHash(t *testing.T, newHash newHashFunc, factor int) ConsistencyHash {
	assert := require.New(t)
	cHash, err := newHash(factor)
	assert.NoError(err)
	assert.NotNil(cHash)
	for i := 0; i < 10; i++ {
		cHash.AddNode(&Node{Name: fmt.Sprintf("node.%d", i)})
	}
	return cHash
}

// reportBalance looks up one million random keys and prints the mean,
// variance and standard deviation of the per-node key counts, followed by
// the counts themselves.
func reportBalance(cHash ConsistencyHash) {
	stat := make(map[string]int, 10)
	for j := 0; j < 1000000; j++ {
		key := rand.Int()
		node := cHash.GetNode(fmt.Sprintf("key:%d", key))
		stat[node.Name] = stat[node.Name] + 1
	}
	MakeIterator(func(it chan uint32) {
		for _, v := range stat {
			it <- uint32(v)
		}
		close(it)
	}).calcSD()
	fmt.Println("节点分布:", stat) // "node distribution"
}

func TestBalanced50(t *testing.T) {
	reportBalance(buildHash(t, NewConsistencyHash, 50))
}

func TestBalanced100(t *testing.T) {
	cHash := buildHash(t, NewConsistencyHash, 100)
	// Also print the total width of the hash ranges owned by each real node.
	impl := cHash.(*conHashImpl)
	sum := make(map[string]uint32)
	for i := 0; i < len(impl.ring.VirtualNodes); i++ {
		node := impl.ring.VirtualNodes[i]
		sum[node.name[:6]] = sum[node.name[:6]] + node.between[1] - node.between[0]
	}
	fmt.Println(sum)
	reportBalance(cHash)
}

func TestBalanced150(t *testing.T) {
	reportBalance(buildHash(t, NewConsistencyHash, 150))
}

func TestBalanced200(t *testing.T) {
	reportBalance(buildHash(t, NewConsistencyHash, 200))
}

func TestBalanced300(t *testing.T) {
	reportBalance(buildHash(t, NewConsistencyHash, 300))
}

func TestBalanced2_300(t *testing.T) {
	reportBalance(buildHash(t, NewConsistencyHash2, 300))
}

// MakeIterator produces a sequence of values on a channel and closes it.
type MakeIterator func(chan uint32)

func (sd MakeIterator) reset() chan uint32 {
	it := make(chan uint32)
	go sd(it)
	return it
}

// calcSD prints the mean, variance and standard deviation of the sequence.
// It iterates twice: once for the mean, once for the squared deviations.
func (sd MakeIterator) calcSD() {
	it := sd.reset()
	sum := uint64(0)
	count := uint64(0)
	for v := range it {
		sum += uint64(v)
		count += 1
	}
	avg := sum / count
	fmt.Println("avg: ", avg)
	it = sd.reset()
	sum2 := float64(0)
	for v := range it {
		sum2 += (float64(v) - float64(avg)) * (float64(v) - float64(avg))
	}
	f := sum2 / float64(count)
	fmt.Println(" 方差: ", f) // "variance"
	r := math.Sqrt(f)
	fmt.Println("标准差 ", r) // "standard deviation"
}
Test results
GOROOT=/usr/local/go #gosetup
GOPATH=/Users/develper/go #gosetup
/usr/local/go/bin/go test -c -o /private/var/folders/yr/b1sl4_z97g9grtpfdq4p25bc0000gn/T/___hash_test_go unipus.cn/misc/consistency_hash #gosetup
/usr/local/go/bin/go tool test2json -t /private/var/folders/yr/b1sl4_z97g9grtpfdq4p25bc0000gn/T/___hash_test_go -test.v -test.run "^TestCeiling|TestCeiling2|TestAddNode|TestHashcode|TestNewConsistencyHash|TestBalanced50|TestBalanced100|TestBalanced150|TestBalanced200|TestBalanced300|TestBalanced2_300$" #gosetup
=== RUN   TestCeiling
--- PASS: TestCeiling (0.00s)
=== RUN   TestCeiling2
--- PASS: TestCeiling2 (0.00s)
=== RUN   TestAddNode
===== [: 10 - [0 10] : 100 - [11 100] : 1000 - [101 1000] : 10000 - [1001 10000] : 100000 - [10001 100000]]
--- PASS: TestAddNode (0.00s)
=== RUN   TestHashcode
2570533916 183424597
avg:  100000
 方差:  33568.6
标准差  183.2173572563473
--- PASS: TestHashcode (0.44s)
=== RUN   TestNewConsistencyHash
a =>  &{node1}
b =>  &{node1}
a =>  &{node1}
b =>  &{node1}
c =>  &{node2}
d =>  &{node2}
--- PASS: TestNewConsistencyHash (0.00s)
=== RUN   TestBalanced50
avg:  100000
 方差:  2.18434133e+08
标准差  14779.517346652427
节点分布: map[node.0:111141 node.1:132610 node.2:91803 node.3:87149 node.4:87838 node.5:91314 node.6:105140 node.7:80381 node.8:111317 node.9:101307]
--- PASS: TestBalanced50 (0.66s)
=== RUN   TestBalanced100
map[node.0:446011767 node.1:393943837 node.2:449561198 node.3:436100962 node.4:387096975 node.5:380953707 node.6:479438298 node.7:418214358 node.8:466233340 node.9:427361096]
avg:  100000
 方差:  5.93047092e+07
标准差  7700.955083624369
节点分布: map[node.0:106398 node.1:91814 node.2:104668 node.3:101523 node.4:89681 node.5:88410 node.6:111479 node.7:97442 node.8:109037 node.9:99548]
--- PASS: TestBalanced100 (0.68s)
=== RUN   TestBalanced150
avg:  100000
 方差:  2.89699606e+07
标准差  5382.374996226108
节点分布: map[node.0:104143 node.1:106628 node.2:88936 node.3:96565 node.4:103503 node.5:104951 node.6:99403 node.7:104164 node.8:97516 node.9:94191]
--- PASS: TestBalanced150 (0.70s)
=== RUN   TestBalanced200
avg:  100000
 方差:  4.0203905e+07
标准差  6340.654934626233
节点分布: map[node.0:102471 node.1:94765 node.2:99744 node.3:102159 node.4:102237 node.5:93901 node.6:87518 node.7:106117 node.8:99760 node.9:111328]
--- PASS: TestBalanced200 (0.70s)
=== RUN   TestBalanced300
avg:  100000
 方差:  1.41155058e+07
标准差  3757.060792694204
节点分布: map[node.0:104197 node.1:101440 node.2:105330 node.3:97980 node.4:96570 node.5:92179 node.6:101024 node.7:97780 node.8:103154 node.9:100346]
--- PASS: TestBalanced300 (0.73s)
=== RUN   TestBalanced2_300
avg:  100000
 方差:  2.63993778e+07
标准差  5138.032483353915
节点分布: map[node.0:104364 node.1:100916 node.2:97636 node.3:98447 node.4:88917 node.5:101350 node.6:105112 node.7:93866 node.8:106036 node.9:103356]
--- PASS: TestBalanced2_300 (0.82s)
PASS
Process finished with exit code 0
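Conclusion:
Among the configurations tested, 300 virtual nodes per real node gave the smallest standard deviation (标准差 in the output, about 3757 for the slice-based implementation), so the keys were distributed most evenly there. The improvement is not strictly monotonic in this run: 200 virtual nodes (about 6341) came out less even than 150 (about 5382).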
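The standard deviation behind this comparison can be recomputed directly from the per-node counts printed in the test results. The sketch below does that for the TestBalanced200 and TestBalanced300 distributions. The helper name `stdDev` is an assumption, and unlike calcSD in the test code it works on a slice and uses a floating-point mean.

```go
package main

import (
	"fmt"
	"math"
)

// stdDev (hypothetical helper) returns the population standard deviation
// of the given per-node key counts.
func stdDev(counts []int) float64 {
	if len(counts) == 0 {
		return 0
	}
	var sum float64
	for _, c := range counts {
		sum += float64(c)
	}
	mean := sum / float64(len(counts))

	var squares float64
	for _, c := range counts {
		d := float64(c) - mean
		squares += d * d
	}
	return math.Sqrt(squares / float64(len(counts)))
}

func main() {
	// Per-node key counts copied from the test output above.
	balanced200 := []int{102471, 94765, 99744, 102159, 102237, 93901, 87518, 106117, 99760, 111328}
	balanced300 := []int{104197, 101440, 105330, 97980, 96570, 92179, 101024, 97780, 103154, 100346}

	fmt.Printf("200 virtual nodes: %.2f\n", stdDev(balanced200))
	fmt.Printf("300 virtual nodes: %.2f\n", stdDev(balanced300))
}
```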

麻辣