「赤兔」Chitu 框架深度解讀(十二):分佈式並行初始化與管理

大模型訓練和推理通常需要在多個設備(GPU/NPU)上並行進行。「赤兔」Chitu 框架支持多種並行策略,包括張量並行 (TP)、流水線並行 (PP)、數據並行 (DP) 和專家並行 (EP)。其分佈式並行環境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模塊負責。

核心概念:CommGroup

distributed/comm_group.py 定義了 CommGroup 類,它是對 PyTorch ProcessGroup 的封裝和擴展。

  • 初始化: CommGroup 根據傳入的 rank_list (一個包含多個 Rank 列表的列表,每個子列表代表一個通信組) 和當前進程的全局 rank 來創建對應的 ProcessGroup
  • 關鍵屬性:
  • group: 底層的 PyTorch ProcessGroup
  • cpu_group: 對應的 CPU ProcessGroup (用於 CPU 上的集合通信)。
  • ranks_in_group: 當前 CommGroup 包含的所有 Rank 列表。
  • group_size: 當前進程所在通信組的大小。
  • rank_in_group: 當前進程在所在通信組內的局部 Rank。
  • is_first_rank/is_last_rank: 判斷當前進程是否是組內的第一個/最後一個 Rank。
  • 通信操作封裝: 提供了對 torch.distributed 常用通信原語(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封裝,自動傳入正確的 group 參數。

CommGroup 的設計簡化了在不同並行維度上進行通信的操作,使得上層代碼無需手動管理多個 ProcessGroup 對象。

並行狀態管理 (parallel_state.py)

distributed/parallel_state.py 負責初始化和維護不同並行維度的 CommGroup 實例,並提供全局訪問接口。

  • 全局變量: 定義了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局變量,用於存儲各個並行維度的 CommGroup 實例。
  • 初始化函數 (initialize_parallel_groups): 這是並行設置的核心入口。
  • 輸入: TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
  • 獲取環境信息: 獲取全局 rank, local_rank, world_size
  • 按序初始化: 依次調用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
  • 初始化邏輯: 每個 initialize_*_group 函數根據並行維度的大小和當前 rank 計算出該維度對應的 rank_list,然後創建 CommGroup 實例並賦值給相應的全局變量。例如:
  • initialize_tp_group: world_size 被劃分為 world_size // tp_size 個 TP 組,每個組包含 tp_size 個連續的 Rank。
  • initialize_pp_group: world_size 被劃分為 world_size // pp_size 個 PP 組,每個組包含跨 TP 和 DP 維度、間隔為 num_pp_groups 的 Rank。
  • initialize_dp_group: 類似 PP 組的劃分方式。
  • initialize_ep_group: 邏輯稍複雜。如果 ep_size > 1
  • tp_size == ep_sizedp_size == 1,則 EP 組直接複用 TP 組 (_EP_GROUP = _TP_GROUP)。
  • dp_size == ep_sizetp_size == 1,則 EP 組直接複用 DP 組 (_EP_GROUP = _DP_GROUP)。
  • 否則,創建新的 EP 通信組,通常是連續的 Rank 組成。
  • 如果 ep_size == 1,則每個 Rank 自己構成一個 EP 組。
  • 特殊處理: initialize_pp_group 中包含了針對 Ascend NPU 的特殊處理,為流水線相鄰 Stage 之間創建了額外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是為了優化 P2P 通信。
  • 訪問接口: 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函數,方便全局訪問並行狀態信息和通信組。
  • 銷燬: destroy_parallel_groups() 負責銷燬創建的通信組。

使用流程

  1. 在程序啓動時,根據配置確定 TP, PP, DP, EP 的大小。
  2. 調用 initialize_parallel_groups 初始化所有並行通信組。
  3. 在模型代碼或算子實現中,通過 get_tp_group(), get_ep_group() 等接口獲取相應的 CommGroup
  4. 調用 CommGroup 實例提供的通信方法(如 tp_group.all_reduce(tensor))執行集合通信。

總結

「赤兔」的分佈式並行管理模塊設計清晰,通過 CommGroup 封裝了底層的通信細節,並通過 parallel_state 模塊提供了統一的初始化入口和全局訪問接口。這種設計使得在代碼中實現和管理複雜的混合並行策略(如 TP+PP+DP+EP)變得更加方便和規範。對 EP 組複用 TP/DP 組以及為 NPU 創建 PP Pair Group 的特殊處理,也體現了其在特定場景下的優化考慮。# 「赤兔」Chitu 框架深度解讀(十二):分佈式並行初始化與管理

大模型訓練和推理通常需要在多個設備(GPU/NPU)上並行進行。「赤兔」Chitu 框架支持多種並行策略,包括張量並行 (TP)、流水線並行 (PP)、數據並行 (DP) 和專家並行 (EP)。其分佈式並行環境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模塊負責。

核心概念:CommGroup

distributed/comm_group.py 定義了 CommGroup 類,它是對 PyTorch ProcessGroup 的封裝和擴展。

  • 初始化: CommGroup 根據傳入的 rank_list (一個包含多個 Rank 列表的列表,每個子列表代表一個通信組) 和當前進程的全局 rank 來創建對應的 ProcessGroup
  • 關鍵屬性:
  • group: 底層的 PyTorch ProcessGroup
  • cpu_group: 對應的 CPU ProcessGroup (用於 CPU 上的集合通信)。
  • ranks_in_group: 當前 CommGroup 包含的所有 Rank 列表。
  • group_size: 當前進程所在通信組的大小。
  • rank_in_group: 當前進程在所在通信組內的局部 Rank。
  • is_first_rank/is_last_rank: 判斷當前進程是否是組內的第一個/最後一個 Rank。
  • 通信操作封裝: 提供了對 torch.distributed 常用通信原語(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封裝,自動傳入正確的 group 參數。

CommGroup 的設計簡化了在不同並行維度上進行通信的操作,使得上層代碼無需手動管理多個 ProcessGroup 對象。

並行狀態管理 (parallel_state.py)

distributed/parallel_state.py 負責初始化和維護不同並行維度的 CommGroup 實例,並提供全局訪問接口。

  • 全局變量: 定義了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局變量,用於存儲各個並行維度的 CommGroup 實例。
  • 初始化函數 (initialize_parallel_groups): 這是並行設置的核心入口。
  • 輸入: TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
  • 獲取環境信息: 獲取全局 rank, local_rank, world_size
  • 按序初始化: 依次調用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
  • 初始化邏輯: 每個 initialize_*_group 函數根據並行維度的大小和當前 rank 計算出該維度對應的 rank_list,然後創建 CommGroup 實例並賦值給相應的全局變量。例如:
  • initialize_tp_group: world_size 被劃分為 world_size // tp_size 個 TP 組,每個組包含 tp_size 個連續的 Rank。
  • initialize_pp_group: world_size 被劃分為 world_size // pp_size 個 PP 組,每個組包含跨 TP 和 DP 維度、間隔為 num_pp_groups 的 Rank。
  • initialize_dp_group: 類似 PP 組的劃分方式。
  • initialize_ep_group: 邏輯稍複雜。如果 ep_size > 1
  • tp_size == ep_sizedp_size == 1,則 EP 組直接複用 TP 組 (_EP_GROUP = _TP_GROUP)。
  • dp_size == ep_sizetp_size == 1,則 EP 組直接複用 DP 組 (_EP_GROUP = _DP_GROUP)。
  • 否則,創建新的 EP 通信組,通常是連續的 Rank 組成。
  • 如果 ep_size == 1,則每個 Rank 自己構成一個 EP 組。
  • 特殊處理: initialize_pp_group 中包含了針對 Ascend NPU 的特殊處理,為流水線相鄰 Stage 之間創建了額外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是為了優化 P2P 通信。
  • 訪問接口: 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函數,方便全局訪問並行狀態信息和通信組。
  • 銷燬: destroy_parallel_groups() 負責銷燬創建的通信組。

使用流程

  1. 在程序啓動時,根據配置確定 TP, PP, DP, EP 的大小。
  2. 調用 initialize_parallel_groups 初始化所有並行通信組。
  3. 在模型代碼或算子實現中,通過 get_tp_group(), get_ep_group() 等接口獲取相應的 CommGroup
  4. 調用 CommGroup 實例提供的通信方法(如 tp_group.all_reduce(tensor))執行集合通信。

總結

「赤兔」的分佈式並行管理模塊設計清晰,通過 CommGroup 封裝了底層的通信細節,並通過 parallel_state 模塊提供了統一的初始化入口和全局訪問接口。這種設計使得在代碼中實現和管理複雜的混合並行策略(如 TP+PP+DP+EP)變得更加方便和規範。對 EP 組複用 TP/DP 組以及為 NPU 創建 PP Pair Group 的特殊處理,也體現了其在特定場景下的優化考慮。