「赤兔」Chitu 框架深度解讀(十二):分佈式並行初始化與管理
大模型訓練和推理通常需要在多個設備(GPU/NPU)上並行進行。「赤兔」Chitu 框架支持多種並行策略,包括張量並行 (TP)、流水線並行 (PP)、數據並行 (DP) 和專家並行 (EP)。其分佈式並行環境的初始化和管理由 distributed/parallel_state.py 和 distributed/comm_group.py 模塊負責。
核心概念:CommGroup
distributed/comm_group.py 定義了 CommGroup 類,它是對 PyTorch ProcessGroup 的封裝和擴展。
- 初始化:
CommGroup根據傳入的rank_list(一個包含多個 Rank 列表的列表,每個子列表代表一個通信組) 和當前進程的全局rank來創建對應的ProcessGroup。 - 關鍵屬性:
group: 底層的 PyTorchProcessGroup。cpu_group: 對應的 CPUProcessGroup(用於 CPU 上的集合通信)。ranks_in_group: 當前CommGroup包含的所有 Rank 列表。group_size: 當前進程所在通信組的大小。rank_in_group: 當前進程在所在通信組內的局部 Rank。is_first_rank/is_last_rank: 判斷當前進程是否是組內的第一個/最後一個 Rank。
- 通信操作封裝: 提供了對
torch.distributed常用通信原語(如broadcast,all_reduce,all_gather,reduce_scatter等)的封裝,自動傳入正確的group參數。
CommGroup 的設計簡化了在不同並行維度上進行通信的操作,使得上層代碼無需手動管理多個 ProcessGroup 對象。
並行狀態管理 (parallel_state.py)
distributed/parallel_state.py 負責初始化和維護不同並行維度的 CommGroup 實例,並提供全局訪問接口。
- 全局變量: 定義了
_WORLD_GROUP,_TP_GROUP,_PP_GROUP,_DP_GROUP,_EP_GROUP等全局變量,用於存儲各個並行維度的CommGroup實例。 - 初始化函數 (
initialize_parallel_groups): 這是並行設置的核心入口。
- 輸入: TP, PP, DP, EP 的大小 (
tp_size,pp_size,dp_size,ep_size)。 - 獲取環境信息: 獲取全局
rank,local_rank,world_size。 - 按序初始化: 依次調用
initialize_world_group,initialize_tp_group,initialize_pp_group,initialize_dp_group,initialize_ep_group。 - 初始化邏輯: 每個
initialize_*_group函數根據並行維度的大小和當前rank計算出該維度對應的rank_list,然後創建CommGroup實例並賦值給相應的全局變量。例如:
initialize_tp_group:world_size被劃分為world_size // tp_size個 TP 組,每個組包含tp_size個連續的 Rank。initialize_pp_group:world_size被劃分為world_size // pp_size個 PP 組,每個組包含跨 TP 和 DP 維度、間隔為num_pp_groups的 Rank。initialize_dp_group: 類似 PP 組的劃分方式。initialize_ep_group: 邏輯稍複雜。如果ep_size > 1:
- 若
tp_size == ep_size且dp_size == 1,則 EP 組直接複用 TP 組 (_EP_GROUP = _TP_GROUP)。 - 若
dp_size == ep_size且tp_size == 1,則 EP 組直接複用 DP 組 (_EP_GROUP = _DP_GROUP)。 - 否則,創建新的 EP 通信組,通常是連續的 Rank 組成。
- 如果
ep_size == 1,則每個 Rank 自己構成一個 EP 組。
- 特殊處理:
initialize_pp_group中包含了針對 Ascend NPU 的特殊處理,為流水線相鄰 Stage 之間創建了額外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是為了優化 P2P 通信。
- 訪問接口: 提供
get_world_group(),get_tp_group(),get_pp_group(),get_dp_group(),get_ep_group(),get_tp_size(),get_dp_size(),get_ep_size()等函數,方便全局訪問並行狀態信息和通信組。 - 銷燬:
destroy_parallel_groups()負責銷燬創建的通信組。
使用流程
- 在程序啓動時,根據配置確定 TP, PP, DP, EP 的大小。
- 調用
initialize_parallel_groups初始化所有並行通信組。 - 在模型代碼或算子實現中,通過
get_tp_group(),get_ep_group()等接口獲取相應的CommGroup。 - 調用
CommGroup實例提供的通信方法(如tp_group.all_reduce(tensor))執行集合通信。
總結
「赤兔」的分佈式並行管理模塊設計清晰,通過 CommGroup 封裝了底層的通信細節,並通過 parallel_state 模塊提供了統一的初始化入口和全局訪問接口。這種設計使得在代碼中實現和管理複雜的混合並行策略(如 TP+PP+DP+EP)變得更加方便和規範。對 EP 組複用 TP/DP 組以及為 NPU 創建 PP Pair Group 的特殊處理,也體現了其在特定場景下的優化考慮。# 「赤兔」Chitu 框架深度解讀(十二):分佈式並行初始化與管理
大模型訓練和推理通常需要在多個設備(GPU/NPU)上並行進行。「赤兔」Chitu 框架支持多種並行策略,包括張量並行 (TP)、流水線並行 (PP)、數據並行 (DP) 和專家並行 (EP)。其分佈式並行環境的初始化和管理由 distributed/parallel_state.py 和 distributed/comm_group.py 模塊負責。
核心概念:CommGroup
distributed/comm_group.py 定義了 CommGroup 類,它是對 PyTorch ProcessGroup 的封裝和擴展。
- 初始化:
CommGroup根據傳入的rank_list(一個包含多個 Rank 列表的列表,每個子列表代表一個通信組) 和當前進程的全局rank來創建對應的ProcessGroup。 - 關鍵屬性:
group: 底層的 PyTorchProcessGroup。cpu_group: 對應的 CPUProcessGroup(用於 CPU 上的集合通信)。ranks_in_group: 當前CommGroup包含的所有 Rank 列表。group_size: 當前進程所在通信組的大小。rank_in_group: 當前進程在所在通信組內的局部 Rank。is_first_rank/is_last_rank: 判斷當前進程是否是組內的第一個/最後一個 Rank。
- 通信操作封裝: 提供了對
torch.distributed常用通信原語(如broadcast,all_reduce,all_gather,reduce_scatter等)的封裝,自動傳入正確的group參數。
CommGroup 的設計簡化了在不同並行維度上進行通信的操作,使得上層代碼無需手動管理多個 ProcessGroup 對象。
並行狀態管理 (parallel_state.py)
distributed/parallel_state.py 負責初始化和維護不同並行維度的 CommGroup 實例,並提供全局訪問接口。
- 全局變量: 定義了
_WORLD_GROUP,_TP_GROUP,_PP_GROUP,_DP_GROUP,_EP_GROUP等全局變量,用於存儲各個並行維度的CommGroup實例。 - 初始化函數 (
initialize_parallel_groups): 這是並行設置的核心入口。
- 輸入: TP, PP, DP, EP 的大小 (
tp_size,pp_size,dp_size,ep_size)。 - 獲取環境信息: 獲取全局
rank,local_rank,world_size。 - 按序初始化: 依次調用
initialize_world_group,initialize_tp_group,initialize_pp_group,initialize_dp_group,initialize_ep_group。 - 初始化邏輯: 每個
initialize_*_group函數根據並行維度的大小和當前rank計算出該維度對應的rank_list,然後創建CommGroup實例並賦值給相應的全局變量。例如:
initialize_tp_group:world_size被劃分為world_size // tp_size個 TP 組,每個組包含tp_size個連續的 Rank。initialize_pp_group:world_size被劃分為world_size // pp_size個 PP 組,每個組包含跨 TP 和 DP 維度、間隔為num_pp_groups的 Rank。initialize_dp_group: 類似 PP 組的劃分方式。initialize_ep_group: 邏輯稍複雜。如果ep_size > 1:
- 若
tp_size == ep_size且dp_size == 1,則 EP 組直接複用 TP 組 (_EP_GROUP = _TP_GROUP)。 - 若
dp_size == ep_size且tp_size == 1,則 EP 組直接複用 DP 組 (_EP_GROUP = _DP_GROUP)。 - 否則,創建新的 EP 通信組,通常是連續的 Rank 組成。
- 如果
ep_size == 1,則每個 Rank 自己構成一個 EP 組。
- 特殊處理:
initialize_pp_group中包含了針對 Ascend NPU 的特殊處理,為流水線相鄰 Stage 之間創建了額外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是為了優化 P2P 通信。
- 訪問接口: 提供
get_world_group(),get_tp_group(),get_pp_group(),get_dp_group(),get_ep_group(),get_tp_size(),get_dp_size(),get_ep_size()等函數,方便全局訪問並行狀態信息和通信組。 - 銷燬:
destroy_parallel_groups()負責銷燬創建的通信組。
使用流程
- 在程序啓動時,根據配置確定 TP, PP, DP, EP 的大小。
- 調用
initialize_parallel_groups初始化所有並行通信組。 - 在模型代碼或算子實現中,通過
get_tp_group(),get_ep_group()等接口獲取相應的CommGroup。 - 調用
CommGroup實例提供的通信方法(如tp_group.all_reduce(tensor))執行集合通信。
總結
「赤兔」的分佈式並行管理模塊設計清晰,通過 CommGroup 封裝了底層的通信細節,並通過 parallel_state 模塊提供了統一的初始化入口和全局訪問接口。這種設計使得在代碼中實現和管理複雜的混合並行策略(如 TP+PP+DP+EP)變得更加方便和規範。對 EP 組複用 TP/DP 組以及為 NPU 創建 PP Pair Group 的特殊處理,也體現了其在特定場景下的優化考慮。