背景
刚接触 spark-streaming,然后写了一个 WordCount 程序,对于不停流进来的数据,需要累加单词出现的次数,这时就需要把前一段时间的结果持久化,而不是数据计算过后就抛弃,在网上搜索到 spark-streaming 可以通过updateStateByKey
和mapWithState
来实现这种有状态的流管理,后者虽然在 spark1.6.x 还是一个实验性的实现,不过由于它的实现思想以及性能都不错,所以简单通过阅读一下源码了解一下原理(其实太深入也不定看得懂~~).网上也有一些文章对比分析两种实现方式,不过始终觉得自己动手写写画画,按照自己的思路可能更有利于加深理解。
有状态的流,在 WordCount 这个例子中,状态(State)就是表示某个单词(Word)在这些流过的数据里面出现过的次数(Count),就像无状态的 http 一样通过 Cookie 或者 Session 来维持交互的状态,而流也一样,通过mapWithState
实现了有状态的流。
工作流程
个人理解,mapWithState
实现有状态管理主要是通过两点:a)历史状态需要在内存中维护,这里必需的了,updateStateBykey
也是一样。b)自定义更新状态的 mappingFunction,这些就是具体的业务功能实现逻辑了(什么时候需要更新状态)
简单地画了一下mapWithState
的工作过程,首先数据像水流一样从左侧的箭头流入,把mapWithState
看成一个转换器的话,mappingFunc 就是转换的规则,流入的新数据(key-value)结合历史状态(通过 key 从内存中获取的历史状态)进行一些自定义逻辑的更新等操作,最终从红色箭头中流出。
具体类及关系
主要涉及的类并不是很多,不过对 scala 不熟还是花了不少时间去看,一些具体的计算逻辑在:MapWithStateStreamImpl,InternalMapWithStateStream,MapWithStateRDD,MapWithStateRDDRecord
.整个图可以从上至下看,从调用mapWithState
到底层的存储,或者说从 DStream->MapWithStateRDD(RDD[MapWithStateRDDRecord])->MapWithStateRDDRecord->Map 这种有层次关系的结构,结合下面相应的源码片断会更容易理解。
源码分析
以 WordCount 例子为切入点,这个例子是从 kafka 接收数据的,为了方便测试的话直接在 sock 获取数据即可:
val sparkConf = new SparkConf().setAppName("WordCount").setMaster("local")
val ssc = new StreamingContext(sparkConf, Seconds(10))
ssc.checkpoint("d:\\tmp")
val params = Map("bootstrap.servers" -> "master:9092", "group.id" -> "scala-stream-group")
val topic = Set("test")
val initialRDD = ssc.sparkContext.parallelize(List[(String, Int)]())
val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](ssc, params, topic)
val word = messages.flatMap(_._2.split(" ")).map { x => (x, 1) }
//自定义mappingFunction,累加单词出现的次数并更新状态
val mappingFunc = (word: String, count: Option[Int], state: State[Int]) => {
val sum = count.getOrElse(0) + state.getOption.getOrElse(0)
val output = (word, sum)
state.update(sum)
output
}
//调用mapWithState进行管理流数据的状态
val stateDstream = word.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD)).print()
ssc.start()
ssc.awaitTermination()
复制代码
进入mapWithState
,发现这是PairDStreamFunctions
的方法,由于在 DStream 伴生对象中定义了隐式转换函数toPairDStreamFunctions
,使得将 DStream 类型隐式转换成了 PairDStreamFunctions 类型,从而可以使用mapWithState
方法:
implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
PairDStreamFunctions[K, V] = {
new PairDStreamFunctions[K, V](stream)
}
复制代码
看一下PairDStreamFunctions.mapWithState()
方法:
@Experimental
def mapWithState[StateType: ClassTag, MappedType: ClassTag](
spec: StateSpec[K, V, StateType, MappedType]
): MapWithStateDStream[K, V, StateType, MappedType] = {
new MapWithStateDStreamImpl[K, V, StateType, MappedType](
self, #这是本批次DStream对象
spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
)
}
复制代码
mapWithState
在 spark1.6.x 还处于试验阶段(@Experimental),通过StateSpec
包装了 mappingFunction 函数(还可以设置状态的超时时间、初始化的状态数据等),然后将StateSpec
作为参数传给mapWithState
。mappingFunction 函数将会被应用于本批次流数据的所有记录(key-value),并与对应 key(唯一)的历史状态(通过 key 获取的 State 对象)进行相应的逻辑运算(就是 mappingFunction 的函数体)。
在mapWithState
中以本批次 DStream 对象和 StateSpecImpl(它是 StateSpec 的子类)的实例对象作为参数构造出 MapWithStateDStreamImpl(MapWithStateDStream 的子类)。从名称就可以看出它是代表着一个有状态的流,看一下它的一个属性和一个方法:
private val internalStream =
new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
......
# 这个方法只是把需要返回的数据进行一下转换,主要的逻辑在internalStream的computer里面
override def compute(validTime: Time): Option[RDD[MappedType]] = {
internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
}
复制代码
MapWithStateDStreamImpl
中没有什么逻辑处理,在 computer 中调用了internalStream.getOrCompute
,这个其实是父类 DStream 的方法,然后在DStream.getOrCompute
又回调了子类InternalMapWithStateDStream.compute
,那么看下这个方法的内容:
InternalMapWithStateDStream: line 132
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
// 计算上一个状态的RDD
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
// 这个rdd是RDD[MapWithStateRDDRecord[K, S, E]]类型,其实就是一个MapWithStateRDD
if (rdd.partitioner != Some(partitioner)) {
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
// _.stateMap.getAll()就是获取了所有的状态数据(往会下看会明白),在此基础创建MapWithStateRDD
MapWithStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
// 首批数据流入时会进入该分支
MapWithStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
)
}
// 本批次流数据构成的RDD:dataRDD
val dataRDD = parent.getOrCompute(validTime).getOrElse {
context.sparkContext.emptyRDD[(K, V)]
}
val partitionedDataRDD = dataRDD.partitionBy(partitioner) // 重新分区
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
// 构造MapWithStateRDD
Some(new MapWithStateRDD(
prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
}
复制代码
compute 方法是通过时间来生成 MapWithStateRDD 的,首先获取到前一个状态的 prevStateRDD(validTime - 一次 batch 时间):
如果可以得到结果 MapWithStateRDD
判断该 RDD 是否已经正确分区,后面的计算是分别每个分区对应的旧状态与新流入的数据进行计算的,如果两者的分区不一致,会导致计算出错(即会更新到错误的状态)。如果正确分区了,那么直接返回 RDD,否则通过MapWithStateRDD.createFromRDD
指定 partitioner 对数据进行重新分区
如果计算不可以得到结果 MapWithStateRDD
通过MapWithStateRDD.createFromPairRDD
创建一个,如果在 WordCount 代码中有对 StateSpec 进行设置 initialStateRDD 初始值,那么将在 initialStateRDD 基础上创建 MapWithStateRDD,否则创建一个空的 MapWithStateRDD,注意,这里使用的是同一个 partitioner 以保证具有相同的分区
然后再从 DStream 中获取本次 batch 数据的 dataRDD(validTime),如果没有,则重建一个空的 RDD,然后将该 RDD 重新分区,使得与上面的 MapWithStateRDD 具有相同的分区。
最后就是构造 MapWithStateRDD 对象:
new MapWithStateRDD(
prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime)
复制代码
上面 WordCount 例子中的 mappingFunc 函数就是这里的 mappingFunction,而参数 word 与 count 来自 partitionedDataRDD 中的数据,参数 state 来自 prevStateRDD 中的数据,基本已经符合上文中第一张图了。
MapWithStateRDD 包含了前一个状态数据(preStateRDD)以及本次需要进行更新的流数据(dataRDD),再看看它的 compute 过程:
override def compute(
partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
val prevStateRDDIterator = prevStateRDD.iterator(
stateRDDPartition.previousSessionRDDPartition, context)
val dataIterator = partitionedDataRDD.iterator(
stateRDDPartition.partitionedDataRDDPartition, context)
# prevRecord代表prevStateRDD一个分区的数据
val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
val newRecord = MapWithStateRDDRecord.updateRecordWithData(
prevRecord,
dataIterator,
mappingFunction,
batchTime,
timeoutThresholdTime,
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
)
Iterator(newRecord)
}
复制代码
这个方法把前一个状态的 RDD 以及本 batch 的 RDD 交给MapWithStateRDDRecord.updateRecordWithData
进行计算,看下面的代码片段,这里有一个不明白的地方:为什么只获取了 prevStateRDD 的第一个分区数据(preRecord),有人明白的求解答。
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
dataIterator: Iterator[(K, V)],
mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long],
removeTimedoutData: Boolean
): MapWithStateRDDRecord[K, S, E] = {
// 从prevRecord中拷贝历史的状态数据(保存在MapWithStateRDDRecord的stateMap属性中)
val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
// 保存需要返回的数据
val mappedData = new ArrayBuffer[E]
val wrappedState = new StateImpl[S]() // State子类,代表一个状态
// 循环本batch的所有记录(key-value)
dataIterator.foreach { case (key, value) =>
// 根据key从历史状态数据中获取该key对应的历史状态
wrappedState.wrap(newStateMap.get(key))
// 将mappingFunction应用到本次batch的每条记录
val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
} else if (wrappedState.isUpdated
|| (wrappedState.exists && timeoutThresholdTime.isDefined)) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
mappedData ++= returned
} // 判断是否删除、更新等然后维护newStateMap里的状态,并记录返回数据mappedData
// 如果有配置超时,也会进行一些更新操作
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
wrappedState.wrapTimingOutState(state)
val returned = mappingFunction(batchTime, key, None, wrappedState)
mappedData ++= returned
newStateMap.remove(key)
}
}
// 依然是返回一个MapWithStateRDDRecord,只是里面的数据变了
MapWithStateRDDRecord(newStateMap, mappedData)
}
复制代码
在这里面,循环本 batch 的数据,通过 key 获取到对应的状态并应用 mappingFunction 函数到每条记录,在 WordCount 例子中,假如某条记录 key="hello",先获取它的状态(假如它的上一个状态是 3,即出现过了 3 次):
# word="hello",value=1,state=("hello",3)
val mappingFunc = (word: String, value: Option[Int], state: State[Int]) => {
val sum = value.getOrElse(0) + state.getOption.getOrElse(0) #sum=4
val output = (word, sum) #("hello",4)
state.update(sum) #更新key="hello"的状态为4
output # 返回("hello",4)
}
复制代码
MapWithStateRDD 代表着数据状态能够更新的 RDD,但 RDD 本身是不可变的,MapWithStateRDD 的元素是 MapWithStateRDDRecord,而 MapWithStateRDDRecord 主要由 stateMap:StateMap(保存着历史的状态)以及 mappedData:Seq(保存返回,即流出的数据)这两个属性构成,当进行 mapWithState 转换时,根据上文的源码分析可知,其实就是对这两个数据结构进行操作,最终变化的只是这两个属性的数据,而 MapWithStateRDD 的元素依然是 MapWithStateRDDRecord。
StateMap 的内部是了个 OpenHashMapBasedStateMap,当调用 create 方法时会创建一个对象:
/** Companion object for [[StateMap]], with utility methods */
private[streaming] object StateMap {
def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S]
def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
DELTA_CHAIN_LENGTH_THRESHOLD)
new OpenHashMapBasedStateMap[K, S](deltaChainThreshold)
}
}
复制代码
而 OpenHashMapBasedStateMap 又是基于 spark 的一个数据结构org.apache.spark.util.collection.OpenHashMap
,至于再底层的实现有兴趣再看,就不再往下分析了。本文主要是围绕主线展开,还有很多分支的细节是需要去了解清楚,这样更能加深理解。
参考
http://www.jianshu.com/p/1463bc1d81b5
http://www.cnblogs.com/DT-Spark/articles/5616560.html
https://my.oschina.net/corleone/blog/684215
https://spark.apache.org/docs/1.6.1/api/scala/index.html
评论