1 核心概念:為什麼需要它們?

在spark程序中,當一個傳遞給Spark操作(例如map和reduce)的函數在遠程節點上面運行時,Spark操作實際上操作的是這個函數所用變量的一個獨立副本。這些變量會被複制到每台機器上,並且這些變量在遠程機器上的所有更新都不會傳遞迴驅動程序。通常跨任務的讀寫變量是低效的,但是,Spark還是為兩種常見的使用模式提供了兩種有限的共享變量:廣播變(broadcast variable)和累加器(accumulator):

  • 廣播變量:用於將 Driver 端的一個大型只讀數據分發到所有 Executor 端。(Driver -> Executor 的分發)
  • 累加器:用於將 Executor 端的信息聚合回 Driver 端。(Executor -> Driver 的聚合)
    它們都是為了在分佈式環境下減少不必要的網絡傳輸和序列化開銷,並確保操作的正確性和高效性。

2 廣播變量

如果我們要在分佈式計算裏面分發大對象,例如:字典,集合,黑白名單等,這個都會由Driver端進行分發,一般來講,如果這個變量不是廣播變量,那麼每個task就會分發一份,這在task數目十分多的情況下Driver的帶寬會成為系統的瓶頸,而且會大量消耗task服務器上的資源。廣播變量將Driver端的可序列化對象,通過高效的廣播協議分發到每個Executor進程的內存中。在每個Executor進程內,所有運行在該Executor上的task線程共享同一份廣播變量的內存引用,避免了數據的重複傳輸和多次反序列化。

val a = 3 
val broadcast = sc.broadcast(a)  //把變量a廣播出去
val c = broadcast.value  //從廣播變量裏獲取a

2.1 廣播變量不可修改

在 Spark 中,廣播變量傳遞的是對象的引用,而不是對象本身的副本。這意味着:

  • 廣播的是對象引用:每個 executor 獲得的是指向同一個對象的引用
  • 對象本身在 driver 上創建:對象首先在 driver 端創建,然後廣播到所有 executor
    如果廣播的是對象,從技術上來説,executor可以修改廣播變量的值,但是強烈不建議這麼做,有併發修改風險,多個 executor 同時修改同一個對象會導致數據競爭,修改結果無法保證一致性
    ,可能因 JVM 內存模型導致修改在不同 executor 間不可見。
// 推薦:使用不可變對象
case class Config(threshold: Int)  // val 不可變字段

val config = Config(100)
val broadcastConfig = sc.broadcast(config)

// 只能讀取,不能修改
rdd.map { data =>
  val currentThreshold = broadcastConfig.value.threshold  // ✅ 安全
  // ... 只讀操作
}

如果需要更新廣播變量:

var currentConfig = Config(100)
var broadcastConfig = sc.broadcast(currentConfig)

// 需要更新時
currentConfig = Config(200)
broadcastConfig.unpersist()  // 取消舊的廣播
broadcastConfig = sc.broadcast(currentConfig)  // 廣播新對象

2.2 利用廣播變量簡化大表和小表的join操作

兩個RDD進行join操作(即 rdd1.join(rdd2)) 會導致shuffle,這是因為join操作會對key一致的key-vlaue對進行合併,而key相同的key-value對不太可能會在同一個partition, 因此很有可能是需要進行經過網絡進行shuffle的,而shuffle會產生許多中間數據(小文件)並涉及到網絡傳輸,這些通常比較耗時,Spark中要儘量避免shuffle。

優化方法:將小RDD的數據通過broadcast到每個executor中,各大RDD partition分別和小RDD做join操作。
具體是:在driver端將小RDD轉換成數組array並broadcast到各executor端,然後再各executor task中對各partion的大RDD的key-value對和小rdd的key-value對進行join;由於每個executor端都有完整的小RDD,因此小RDD的各partition不需要shuffle到RDD的各partition,小RDD廣播到大RDD的各partition後,各partition分別進行join,最後再執行reduce,所有分區的join結果彙總到driver端。

import org.apache.spark.sql.SparkSession

object BigRDDJoinSmallRDD {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession.builder().master("local[3]").appName("BigRDD Join SmallRDD").getOrCreate()
    val sc = sparkSession.sparkContext
    val list1 = List(("jame",23), ("wade",3), ("kobe",24))
    val list2 = List(("jame", 13), ("wade",6), ("kobe",16))
    val bigRDD = sc.makeRDD(list1)
    val smallRDD = sc.makeRDD(list2)

    println(bigRDD.getNumPartitions)
    println(smallRDD.getNumPartitions)

    // driver端rdd不broadcast廣播smallRDD到各executor,RDD不能被broadcast,需要轉換成數組array
    val  smallRDDB= sc.broadcast(smallRDD.collect())

    val joinedRDD = bigRDD.mapPartitions(partition => {
      val smallRDDBV = smallRDDB.value  // 各個executor端的task讀取廣播value
      partition.map(element => {
        //println(joinUtil(element, smallRDDBV))
        joinUtil(element, smallRDDBV)
      })
    })
    joinedRDD.foreach(x => println(x))



  }


/**
  * join操作:對兩個rdd中的相同key的value1和value2進行聚合,即(key,value1).join(key,value2)得到(key,(value1, vlaue2))
  * 如果bigRDDEle的key和smallRDD的某個key一致,那麼返回(key,(value1, vlaue2))
  * 該方法會在各executor的task上執行
  * */
 def joinUtil(bigRDDEle:(String,Int), smallRDD: Array[(String, Int)]): (String, (Int,Int)) = {
   var joinEle:(String, (Int, Int)) = null
   // 遍歷數組smallRDD
   smallRDD.foreach(smallRDDEle => {
      if(smallRDDEle._1.equals(bigRDDEle._1)){
        // 如果bigRDD中某個元素的key和數組smallRDD的key一致,返回join結果
        joinEle = (bigRDDEle._1, (bigRDDEle._2, smallRDDEle._2))
      }
    })
   joinEle
 }

}

2.3 如何讓spark選擇廣播join

2.3.1 利用配置項強行廣播

使用廣播閾值配置項讓Spark優先選擇Broadcast Joins的關鍵,就是要確保至少有一張表的存儲尺寸小於廣播閾值(spark.sql.autoBroadcastJoinThreshold)。那麼如何估算一張表的大小呢?第一步,把要預估大小的數據表緩存到內存,比如直接在DataFrame或是Dataset上調用cache方法;第二步,讀取Spark SQL執行計劃的統計數據。這是因為,Spark SQL在運行時,就是靠這些統計數據來制定和調整執行策略的。

val df: DataFrame = _
df.cache.count
 
val plan = df.queryExecution.logical
val estimated: BigInt = spark
.sessionState
.executePlan(plan)
.optimizedPlan
.stats
.sizeInBytes

2.3.2 利用API強制廣播

2.3.2.1 Join Hints

Join Hints中的Hints表示“提示”,它指的是在開發過程中使用特殊的語法,明確告知Spark SQL在運行時採用哪種Join策略。一旦你啓用了Join Hints,不管你的數據表是不是滿足廣播閾值,Spark SQL都會盡可能地尊重你的意願和選擇,使用Broadcast Joins去完成數據關聯。
舉個例子,假設有兩張表,一張表的內存大小在100GB量級,另一張小一些,2GB左右。在廣播閾值被設置為2GB的情況下,並沒有觸發Broadcast Joins,但我們又不想花費時間和精力去精確計算小表的內存佔用到底是多大。在這種情況下,就可以用Join Hints來幫我們做優化,僅僅幾句提示就可以達到目的。

val table1: DataFrame = spark.read.parquet(path1)
val table2: DataFrame = spark.read.parquet(path2)
table1.createOrReplaceTempView("t1")
table2.createOrReplaceTempView("t2")
 
val query: String = “select /*+ broadcast(t2) */ * from t1 inner join t2 on t1.key = t2.key”
val queryResutls: DataFrame = spark.sql(query)

也可以在DataFrame的DSL語法中使用Join Hints:

table1.join(table2.hint(“broadcast”), Seq(“key”), “inner”)

不過,Join Hints也有個小缺陷。如果關鍵字拼寫錯誤,Spark SQL在運行時並不會顯示地拋出異常,而是默默地忽略掉拼寫錯誤的hints,假裝它壓根不存在。因此,在使用Join Hints的時候,需要我們在編譯時自行確認Debug和糾錯。

2.3.2.2 廣播函數

如果你不想等到運行時才發現問題,想讓編譯器幫你檢查類似的拼寫錯誤,那麼你可以使用強制廣播的第二種方式:broadcast函數。這個函數是類庫org.apache.spark.sql.functions中的broadcast函數。調用方式非常簡單,比Join Hints還要方便,只需要用broadcast函數封裝需要廣播的數據表即可,如下所示。

import org.apache.spark.sql.functions.broadcast
table1.join(broadcast(table2), Seq(“key”), “inner”)

你可能會問:“既然開發者可以通過Join Hints和broadcast函數強制Spark SQL選擇Broadcast Joins,那我是不是就可以不用理會廣播閾值的配置項了?”其實還真不是。我認為,以廣播閾值配置為主,以強制廣播為輔,往往是不錯的選擇。
廣播閾值的設置,更多的是把選擇權交給Spark SQL,尤其是在AQE的機制下,動態Join策略調整需要這樣的設置在運行時做出選擇。強制廣播更多的是開發者以專家經驗去指導Spark SQL該如何選擇運行時策略。二者相輔相成,並不衝突,開發者靈活地運用就能平衡Spark SQL優化策略與專家經驗在應用中的比例。

2.4 在什麼情況下,不適合把Shuffle Joins轉換為Broadcast Joins?

不適合把 Shuffle Joins 轉換為 Broadcast Joins 的情況主要有以下幾種:
大表不適合廣播,當數據量超過廣播閾值,廣播 largeTable 會失敗並退回到 shuffle
內存限制問題,當Executor 內存不足以容納廣播數據,如果廣播表大小為 1.5GB,可能會導致頻繁的 GC,引發 OOM 錯誤,數據溢出到磁盤,性能反而更差
數據分佈不均勻時,廣播表中有大量重複鍵,廣播會導致所有 executor 都加載大量重複數據,內存浪費嚴重,可能還不如 shuffle
Join 類型限制,Spark 不支持 Broadcast Full Outer Join
網絡帶寬瓶頸,在跨可用區或跨地域集羣中網絡帶寬可能成為瓶頸,可能比 shuffle 更慢(shuffle 是 executor 間交換)
數據特徵不適合,維度表頻繁更新,每次查詢都要重新廣播整個維度表;廣播表包含複雜數據結構,序列化和反序列化開銷巨大
總結:

場景

廣播 Join 適合度

原因

小表(<10MB)

✅ 非常適合

內存開銷小

中等表(10MB-100MB)

⚠️ 需要評估

取決於集羣資源

大表(>100MB)

❌ 不適合

內存壓力大

數據傾斜嚴重

❌ 不適合

內存浪費

Full Outer Join

❌ 不支持

技術限制

高併發環境

❌ 不適合

資源競爭

網絡帶寬有限

❌ 不適合

傳輸瓶頸

// 監控廣播 join 是否合適
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "100MB")

val df = broadcast(mediumTable).join(largeTable, "id")

// 查看執行計劃
df.explain()

// 監控廣播數據大小
val broadcastSize = mediumTable.queryExecution.optimizedPlan.stats.sizeInBytes
println(s"Broadcast table size: ${broadcastSize / 1024 / 1024} MB")

// 如果看到以下情況,應考慮使用 shuffle:
// 1. GC 時間佔比高
// 2. 執行時間比 shuffle 還長
// 3. 頻繁的 spill 到磁盤

3 累加器

3.1 閉包

理解累加器,首先理解scala的閉包。什麼是閉包?在創建函數時,如果需要捕獲自由變量,那麼包含指向被捕獲變量的引用的函數就被稱為閉包函數。在集羣中 Spark 會將對 RDD 的操作處理分解為 Tasks ,每個 Task 由 Executor 執行。而在執行之前,Spark會計算 task 的閉包(也就是 foreach() )。閉包會被序列化併發送給每個 Executor,但是發送給 Executor 的是副本,所以在 Driver 上輸出的依然是 sum 本身。如果想對 sum 變量進行更新,則就要用到接下來我們要講的累加器。

3.2 累加器變量

在spark應用程序中,我們經常會有這樣的需求,如異常監控,調試,記錄符合某特性的數據的數目,這種需求都需要用到計數器,如果一個變量不被聲明為一個累加器,那麼它將在被改變時不會再driver端進行全局彙總,即在分佈式運行時每個task運行的只是原始變量的一個副本,並不能改變原始變量的值,但是當這個變量被聲明為累加器後,該變量就會有分佈式計數的功能。

3.3 如何使用累加器

Spark內置了三種類型的Accumulator,分別是LongAccumulator用來累加整數型,DoubleAccumulator用來累加浮點型,CollectionAccumulator用來累加集合元素。

3.4 自定義累加器

自定義累加器類型的功能在 1.x 版本中就已經提供了,但是使用起來比較麻煩,在 Spark 2.0.0 版本後,累加器的易用性有了較大的改進,而且官方還提供了一個新的抽象類:AccumulatorV2 來提供更加友好的自定義類型累加器的實現方式。官方同時給出了一個實現的示例:CollectionAccumulator,這個類允許以集合的形式收集 Spark 應用執行過程中的一些信息。例如,我們可以用這個類收集 Spark 處理數據過程中的非法數據或者引起異常的異常數據,這對我們處理異常時很有幫助。當然,由於累加器的值最終要匯聚到 Driver 端,為了避免 Driver 端的出現 OOM,需要收集的數據規模不宜過大。
實現自定義類型累加器需要繼承 AccumulatorV2 並覆蓋下面幾個方法:

  • reset 將累加器重置為零
  • add 將另一個值添加到累加器中
  • merge 將另一個相同類型的累加器合併到該累加器中
    下面這個累加器可以用於在程序運行過程中收集一些異常或者非法數據,最終以 List[String] 的形式返回:
package com.sjf.open.spark;

import com.google.common.collect.Lists;
import org.apache.spark.util.AccumulatorV2;

import java.util.ArrayList;
import java.util.List;

/**
 * 自定義累加器 CollectionAccumulator
 * @author sjf0115
 * @Date Created in 下午2:11 18-6-4
 */
public class CollectionAccumulator<T> extends AccumulatorV2<T, List<T>> {

    private List<T> list = Lists.newArrayList();

    @Override
    public boolean isZero() {
        return list.isEmpty();
    }

    @Override
    public AccumulatorV2<T, List<T>> copy() {
        CollectionAccumulator<T> accumulator = new CollectionAccumulator<>();
        synchronized (accumulator) {
            accumulator.list.addAll(list);
        }
        return accumulator;
    }

    @Override
    public void reset() {
        list.clear();
    }

    @Override
    public void add(T v) {
        list.add(v);
    }

    @Override
    public void merge(AccumulatorV2<T, List<T>> other) {
        if(other instanceof CollectionAccumulator){
            list.addAll(((CollectionAccumulator) other).list);
        }
        else {
            throw new UnsupportedOperationException("Cannot merge " + this.getClass().getName() + " with " + other.getClass().getName());
        }
    }

    @Override
    public List<T> value() {
        return new ArrayList<>(list);
    }
}

下面我們在數據處理過程中收集非法座標為例,來看一下我們自定義的累加器如何使用:

package com.sjf.open.spark;

import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;

import java.io.Serializable;
import java.util.List;

/**
 * 自定義累加器示例
 * @author sjf0115
 * @Date Created in 下午2:11 18-6-4
 */
public class CustomAccumulatorExample implements Serializable{

    public static void main(String[] args) {
        String appName = "CustomAccumulatorExample";
        SparkConf conf = new SparkConf().setAppName(appName);
        JavaSparkContext sparkContext = new JavaSparkContext(conf);

        List<String> list = Lists.newArrayList();
        list.add("27.34832,111.32135");
        list.add("34.88478,185.17841");
        list.add("39.92378,119.50802");
        list.add("94,119.50802");

        CollectionAccumulator<String> collectionAccumulator = new CollectionAccumulator<>();
        sparkContext.sc().register(collectionAccumulator, "Illegal Coordinates");
        // 原始座標
        JavaRDD<String> sourceRDD = sparkContext.parallelize(list);
        // 過濾非法座標
        JavaRDD<String> resultRDD = sourceRDD.filter(new Function<String, Boolean>() {
            @Override
            public Boolean call(String str) throws Exception {
                String[] coordinate = str.split(",");
                double lat = Double.parseDouble(coordinate[0]);
                double lon = Double.parseDouble(coordinate[1]);
                if(Math.abs(lat) > 90 || Math.abs(lon) > 180){
                    collectionAccumulator.add(str);
                    return true;
                }
                return false;
            }
        });
        // 輸出
        resultRDD.foreach(new VoidFunction<String>() {
            @Override
            public void call(String coordinate) throws Exception {
                System.out.println("[Data]" + coordinate);
            }
        });
        // 查看異常座標
        for (String coordinate : collectionAccumulator.value()) {
            System.out.println("[Illegal]: " + coordinate);
        }
    }

}

3.5 累加器陷阱

Spark 中的一系列 transformation 操作會構成一個任務鏈,需要通過 action 操作來觸發。累加器也是一樣的,也只能通過 action 觸發更新,所以在 action 操作之前調用 value 方法查看其數值是沒有任何變化的。對於在 action 中更新的累加器,Spark 會保證每個任務對累加器只更新一次,即使重新啓動的任務也不會重新更新該值。而如果在 transformation 中更新的累加器,如果任務或作業 stage 被重新執行,那麼其對累加器的更新可能會執行多次。解決辦法是採用cache打斷任務依賴。

val acc = sc.longAccumulator("count")

// 創建 RDD
val rdd = sc.parallelize(1 to 10)

// transformation - 不會立即執行
val mappedRDD = rdd.map { x =>
  acc.add(1)
  x * 2
}

// 每次 action 都會導致 transformation 重新執行!
mappedRDD.count()  // acc = 10
mappedRDD.collect()  // acc = 20(又執行了一次!)

println(acc.value)  // 輸出 20,不是 10!