To help everyone understand how Structured Streaming runs, I will walk through its basic execution flow and the principles behind it using a code example.
Here is a simple piece of code:
1 val spark = SparkSession
2 .builder
3 .appName("StructuredNetworkWordCount")
4 .master("local[4]")
5
6 .getOrCreate()
7 spark.conf.set("spark.sql.shuffle.partitions", 4)
8
9 import spark.implicits._
10 val words = spark.readStream
11 .format("socket")
12 .option("host", "localhost")
13 .option("port", 9999)
14 .load()
15
16 val df1 = words.as[String]
17 .flatMap(_.split(" "))
18 .toDF("word")
19 .groupBy("word")
20 .count()
21
22 df1.writeStream
23 .outputMode("complete")
24 .format("console")
25 .trigger(ProcessingTime("10 seconds"))
26 .start()
27
28 spark.streams.awaitAnyTermination()
This code is a word count. It reads data from a socket source, splits each line of text on " " into a Dataset of words, turns that into a DataFrame with a labelled column ("word"), then groups by the word column and aggregates a count for each word. Finally the result is written to the console, with a 10-second batch interval.
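For comparison, here is a minimal sketch of the same word count written as an ordinary batch job (reading a hypothetical local file input.txt instead of the socket). The transformation chain is identical, which is the whole point of the Structured Streaming API: a streaming query is written just like a batch query.

import org.apache.spark.sql.SparkSession

object BatchWordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("BatchWordCount")
      .master("local[4]")
      .getOrCreate()
    import spark.implicits._

    // Same pipeline as above, but on a static Dataset read from a file
    // ("input.txt" is a placeholder path).
    val counts = spark.read.textFile("input.txt")
      .flatMap(_.split(" "))
      .toDF("word")
      .groupBy("word")
      .count()

    counts.show()   // a batch query prints once and finishes; no trigger, no sink
    spark.stop()
  }
}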
Now let's analyse how this works. Spark is built around lazy evaluation: in the example above, nothing is actually computed before line 22. The computation (the functions and expressions) is only recorded inside the DataFrame, and real work starts only after the writeStream...start() call that begins on line 22. Why?
Because it lets the Spark engine optimize the query first.
For example: suppose a Dataset df holds (name, age) pairs loaded from a database, and I want to print the first ten people older than 20 to the console. My Spark code would look like this:
1 df.filter { row =>
2   row._2 > 20 }
3   .show(10)
If Spark executed each line as soon as it saw it, then at line 2 it would filter every row in df, keeping everyone older than 20, and only then, at line 3, would it take the first 10 of those results and print them.
See the problem? The output only needs 10 people older than 20, yet every single person got filtered. In fact, once 10 matches have been found, there is no need to filter the rest. This is exactly the kind of thing Spark's lazy evaluation allows the optimizer to fix; a quick way to see it for yourself is sketched below.
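Here is a small self-contained sketch (assuming a toy in-memory Dataset of (name, age) pairs rather than a real database): explain(true) shows that only a plan has been recorded, and that the limit is baked into that single optimized plan instead of being applied after a full filter pass.

import org.apache.spark.sql.SparkSession

object LazyEvalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("LazyEvalSketch")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    // Toy stand-in for the "name, age" table in the text.
    val people = Seq(("Ann", 31), ("Bob", 18), ("Cid", 45)).toDF("name", "age")

    // Nothing runs here: these calls only record transformations in a logical plan.
    val firstAdults = people.filter($"age" > 20).limit(10)

    // Prints the parsed/analyzed/optimized/physical plans; still no job has run.
    // The limit is part of the one plan Spark optimizes as a whole, so Spark can
    // stop early instead of filtering every row and trimming afterwards.
    firstAdults.explain(true)

    firstAdults.show()   // only now is a job actually executed
    spark.stop()
  }
}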
In Spark, no real computation happens before a genuine output operation (an action): the recorded plan is optimized first and only then executed. Let's look at the source code.
Here is the code that Structured Streaming runs every time a batch trigger fires:
1 private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
2 // Request unprocessed data from all sources.
3 newData = reportTimeTaken("getBatch") {
4 availableOffsets.flatMap {
5 case (source, available)
6 if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
7 val current = committedOffsets.get(source)
8 val batch = source.getBatch(current, available)
9 logDebug(s"Retrieving data from $source: $current -> $available")
10 Some(source -> batch)
11 case _ => None
12 }
13 }
14
15 // A list of attributes that will need to be updated.
16 var replacements = new ArrayBuffer[(Attribute, Attribute)]
17 // Replace sources in the logical plan with data that has arrived since the last batch.
18 val withNewSources = logicalPlan transform {
19 case StreamingExecutionRelation(source, output) =>
20 newData.get(source).map { data =>
21 val newPlan = data.logicalPlan
22 assert(output.size == newPlan.output.size,
23 s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
24 s"${Utils.truncatedString(newPlan.output, ",")}")
25 replacements ++= output.zip(newPlan.output)
26 newPlan
27 }.getOrElse {
28 LocalRelation(output)
29 }
30 }
31
32 // Rewire the plan to use the new attributes that were returned by the source.
33 val replacementMap = AttributeMap(replacements)
34 val triggerLogicalPlan = withNewSources transformAllExpressions {
35 case a: Attribute if replacementMap.contains(a) => replacementMap(a)
36 case ct: CurrentTimestamp =>
37 CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
38 ct.dataType)
39 case cd: CurrentDate =>
40 CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
41 cd.dataType, cd.timeZoneId)
42 }
43
44 reportTimeTaken("queryPlanning") {
45 lastExecution = new IncrementalExecution(
46 sparkSessionToRunBatch,
47 triggerLogicalPlan,
48 outputMode,
49 checkpointFile("state"),
50 currentBatchId,
51 offsetSeqMetadata)
52 lastExecution.executedPlan // Force the lazy generation of execution plan
53 }
54
55 val nextBatch =
56 new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
57
58 reportTimeTaken("addBatch") {
59 sink.addBatch(currentBatchId, nextBatch)
60 }
61
62 awaitBatchLock.lock()
63 try {
64 // Wake up any threads that are waiting for the stream to progress.
65 awaitBatchLockCondition.signalAll()
66 } finally {
67 awaitBatchLock.unlock()
68 }
69 }
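Before unpacking this method, here is a toy sketch (hypothetical types, not the real Catalyst classes) of what the logicalPlan transform { case StreamingExecutionRelation(...) => ... } step above is doing: on every trigger, the streaming-source placeholder in the user's plan is swapped for a relation that holds just the newly arrived data, and the rest of the plan is reused unchanged.

object RebindSketch {
  // Stand-ins for Catalyst plan nodes (names are made up for illustration).
  sealed trait Plan
  case class StreamingSourcePlaceholder(sourceName: String) extends Plan
  case class BatchRelation(sourceName: String, rows: Seq[String]) extends Plan
  case class WordCount(child: Plan) extends Plan   // stands in for flatMap/groupBy/count

  // Analogue of `logicalPlan transform { ... }`: rebuild the tree, replacing sources.
  def bindNewData(plan: Plan, newData: Map[String, Seq[String]]): Plan = plan match {
    case StreamingSourcePlaceholder(name) =>
      // No new data for this source means an empty relation, like LocalRelation(output).
      BatchRelation(name, newData.getOrElse(name, Seq.empty))
    case WordCount(child) => WordCount(bindNewData(child, newData))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val userPlan = WordCount(StreamingSourcePlaceholder("socket"))
    // Each trigger yields a fresh, fully batch-like plan bound to that trigger's data:
    println(bindNewData(userPlan, Map("socket" -> Seq("hello world", "hello spark"))))
  }
}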
It's actually quite simple. Everything before line 58 is preparing this batch: fetching the unprocessed data from every source, rewriting the user's logical plan to point at that data, and optimizing it. The triggerLogicalPlan handed over on line 47 is the user's logic rebound to this batch; it is wrapped in an IncrementalExecution, and that, together with sparkSessionToRunBatch (the session to run in) and a RowEncoder (the serializer for rows), makes up a new Dataset. This Dataset carries the plan that will ultimately be executed on the worker nodes. Line 59 then hands it to the sink via addBatch, and that is where the computation actually gets triggered. Let's keep going: since we use the console as the sink (the downstream), here is the console sink's addBatch code:
1 override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
2 val batchIdStr = if (batchId <= lastBatchId) {
3 s"Rerun batch: $batchId"
4 } else {
5 lastBatchId = batchId
6 s"Batch: $batchId"
7 }
8
9 // scalastyle:off println
10 println("-------------------------------------------")
11 println(batchIdStr)
12 println("-------------------------------------------")
13 // scalastyle:off println
14 data.sparkSession.createDataFrame(
15 data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
16 .show(numRowsToShow, isTruncated)
17 }
The key call is .show on line 16 (note that data.collect() on line 15 is itself an action that already materializes the batch and pulls it back to the driver). show is a genuine action; everything before an action is just operators being composed. Let's look at the code behind show:
1 private[sql] def showString(_numRows: Int, truncate: Int = 20): String = {
2 val numRows = _numRows.max(0)
3 val takeResult = toDF().take(numRows + 1)
4 val hasMoreData = takeResult.length > numRows
5 val data = takeResult.take(numRows)
Line 3 takes us into take:
def take(n: Int): Array[T] = head(n)
def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
2 try {
3 qe.executedPlan.foreach { plan =>
4 plan.resetMetrics()
5 }
6 val start = System.nanoTime()
7 val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
8 action(qe.executedPlan)
9 }
10 val end = System.nanoTime()
11 sparkSession.listenerManager.onSuccess(name, qe, end - start)
12 result
13 } catch {
14 case e: Exception =>
15 sparkSession.listenerManager.onFailure(name, qe, e)
16 throw e
17 }
18 }
The name of this function already tells us that the real computation is about to start. On line 7 the action is wrapped in a new SQL execution and kicked off; here is withNewExecutionId:
1 def withNewExecutionId[T](
2 sparkSession: SparkSession,
3 queryExecution: QueryExecution)(body: => T): T = {
4 val sc = sparkSession.sparkContext
5 val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
6 if (oldExecutionId == null) {
7 val executionId = SQLExecution.nextExecutionId
8 sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
9 executionIdToQueryExecution.put(executionId, queryExecution)
10 val r = try {
11 // sparkContext.getCallSite() would first try to pick up any call site that was previously
12 // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
13 // streaming queries would give us call site like "run at <unknown>:0"
14 val callSite = sparkSession.sparkContext.getCallSite()
15
16 sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
17 executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
18 SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
19 try {
20 body
21 } finally {
22 sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
23 executionId, System.currentTimeMillis()))
24 }
25 } finally {
26 executionIdToQueryExecution.remove(executionId)
27 sc.setLocalProperty(EXECUTION_ID_KEY, null)
28 }
29 r
30 } else {
31 // Don't support nested `withNewExecutionId`. This is an example of the nested
32 // `withNewExecutionId`:
33 //
34 // class DataFrame {
35 // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
36 // }
37 //
38 // Note: `collect` will call withNewExecutionId
39 // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
40 // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
41 // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
42 // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
43 //
44 // A real case is the `DataFrame.count` method.
45 throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
46 }
47 }
Look at line 16: it posts a SparkListenerSQLExecutionStart event onto the listener bus, carrying the execution id, the call site, a description of the optimized plan, a timestamp and so on. That event feeds monitoring, for example the SQL tab of the web UI. The actual work is body on line 20: that is the action itself, which runs the physical plan, and the resulting jobs are scheduled by the driver and their tasks shipped to the executors, which then carry out the logic described in the plan. And that's it: a very basic, very simple flow, but hopefully a useful starting point for getting into Spark.
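To make that split concrete, here is a small sketch (using only the public SparkListener API; the class name BatchObserver is made up) of where the line-16 event ends up and of how the jobs launched by body show up:

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}

class BatchObserver extends SparkListener {
  // Fired when an action (show, collect, ...) actually submits a job.
  override def onJobStart(jobStart: SparkListenerJobStart): Unit =
    println(s"job ${jobStart.jobId} started with ${jobStart.stageInfos.size} stages")

  // Non-core events such as SparkListenerSQLExecutionStart arrive here.
  override def onOtherEvent(event: SparkListenerEvent): Unit =
    println(s"other event: ${event.getClass.getSimpleName}")
}

// Register it, e.g. right after building the SparkSession in the word-count example:
//   spark.sparkContext.addSparkListener(new BatchObserver)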