引言
在當今大語言模型(LLM)的浪潮中,模型規模的持續擴張是提升性能的關鍵驅動力。然而,隨着模型參數量的激增,訓練和推理的計算成本也隨之飆升。為了解決這一挑戰,混合專家模型(Mixture-of-Experts, MoE)架構應運而生,並已成為許多前沿模型(如 Mixtral 8x7B, DeepSeek-V3)的核心技術之一。
MoE 的核心思想非常巧妙:它不再要求模型的每一部分處理所有的輸入數據,而是引入了多個“專家”子網絡——這些專家本質上就是標準的前饋網絡。通過一個門控網絡(Gating Network),系統可以為每個輸入(token)動態地、稀疏地選擇並激活一小部分專家進行計算。這樣,模型可以在總參數量巨大的同時,保持單次前向傳播的計算量相對可控,實現了規模與效率的精妙平衡。
MoE 層的完整工作流包含“路由選擇”和“專家計算”兩個核心環節,本文將聚焦於後者。這裏首先從一個直觀易懂的樸素實現入手,來展示 MoE 專家計算的完整邏輯,也為理解後續內容打下基礎。本系列文章還將介紹 Fused MoE, 它是一種通過算子融合進行優化的方案,將多個獨立的計算合併成一個單一的批處理操作,從而加速 MoE 中的專家計算環節。
通過對比這兩種實現,希望能讓讀者深刻理解 Fused MoE 算子的原理。
第一部分:MoE 專家計算的直觀實現
在 MoE 模型中,一旦門控網絡為每個輸入的 token 分配好了對應的專家,接下來的任務就是執行計算。本部分將介紹一種直觀的、基於循環的實現方式。這種方式雖然不是性能最高的,但其邏輯非常清晰,是理解 MoE 工作原理的絕佳起點。
1. 基本設定:從 MLP 到 Expert
首先需要明確 MoE 中的“專家”(Expert)到底是什麼。本質上,一個專家模塊就是一個標準的前饋神經網絡(Feed-Forward Network, FFN),或者説多層感知機(MLP)。在 Transformer 架構中,它通常由兩個或三個線性層組成,負責對 token 的特徵進行非線性變換。
一個 MoE 層正是由多個這樣的 Expert 模塊以及一個負責調度的 gate 模塊組成的集合。
以下是 DeepSeek-V3 中 MoE 模塊的完整實現[1],可以瞭解其具體結構。
class Expert(nn.Module):
2. 基於循環的計算流程
這裏重點分析 MoE 模塊中 forward 方法的核心邏輯。在本文的討論中會忽略共享專家(shared\_experts)和分佈式通信(dist.all\_reduce)部分,專注於專家的計算。
其計算流程可以拆解如下:
- 輸入整形與路由:
- x = x.view(-1, self.dim): 將輸入張量 x 整形為二維,形狀為 [token總數, 特徵維度]。
- weights, indices = self.gate(x): 調用門控網絡,得到每個 token 對應的專家權重 weights 和專家索引 indices。
2. 專家計算循環:
- y = torch.zeros_like(x): 初始化一個全零的輸出張量 y,用於累加結果。
- counts = torch.bincount(...): 這是一個優化。它會統計每個專家被分配了多少個 token。
- for i in range(...): 核心的 for 循環,遍歷當前設備上的所有專家。
- if counts[i] == 0: continue: 如果一個專家沒有被分配任何 token,則直接跳過,避免不必要的計算。
- idx, top = torch.where(indices == i): 查找所有被分配給當前專家 i 的 token 的索引。
- y[idx] += expert(x[idx]) * weights[idx, top, None]: 這是最關鍵的一步。它執行以下操作:
- x[idx]: 篩選出需要由專家 i 處理的 token。
- expert(...): 將這些 token 送入專家網絡進行計算。
- * weights[...]: 將專家輸出與對應的權重相乘。
- y[idx] += ...: 將加權後的結果累加到輸出張量 y 的相應位置。
這種實現方式可以被稱為“專家視角”的計算模式:它的主邏輯是“對於我這個專家,有哪些 token 需要我來處理?”。它一步步地完成計算,邏輯非常清晰。
3. 實例演算:一步步看懂計算過程
為了讓整個過程更具體,這裏用一個完整的例子來手動演算一遍。詳細展示 Expert 內部的每一次矩陣乘法。
設定參數:
- 輸入 token 數量 M = 2
- Token 維度 dim = 3
- 專家中間層維度 inter_dim = 2
- 總專家數量 E = 4
- 激活專家數 topk = 2
預設輸入和路由結果:
- 輸入 x (形狀 [2, 3]):
[[1, 1, 1],
專家索引 indices (形狀 [2, 2]):
[[0, 2], # Token 0 -> Expert 0, Expert 2
- 專家權重 weights (形狀 [2, 2]):
[[0.5, 0.5],
專家參數設定:
我們假設4個專家的所有權重矩陣(w1, w2, w3)都用一個固定的值填充。
- w1 和 w3 的權重矩陣形狀為 [inter_dim, dim] 即 [2, 3]。
- w2 的權重矩陣形狀為 [dim, inter_dim] 即 [3, 2]。
Expert 0 (所有參數為 1):
- W1_0 = [[1, 1, 1], [1, 1, 1]]
- W3_0 = [[1, 1, 1], [1, 1, 1]]
- W2_0 = [[1, 1], [1, 1], [1, 1]]
Expert 1 (所有參數為 2):
- W1_1 = [[2, 2, 2], [2, 2, 2]]
- W3_1 = [[2, 2, 2], [2, 2, 2]]
- W2_1 = [[2, 2], [2, 2], [2, 2]]
Expert 2 (所有參數為 3):
- W1_2 = [[3, 3, 3], [3, 3, 3]]
- W3_2 = [[3, 3, 3], [3, 3, 3]]
- W2_2 = [[3, 3], [3, 3], [3, 3]]
Expert 3 (所有參數為 4):
- W1_3 = [[4, 4, 4], [4, 4, 4]]
- W3_3 = [[4, 4, 4], [4, 4, 4]]
- W2_3 = [[4, 4], [4, 4], [4, 4]]
演算流程:
1. 初始化:
- y 被初始化為 [[0, 0, 0], [0, 0, 0]]。
2. for i = 0 (處理 Expert 0):
- torch.where(indices == 0) 找到 idx = [0]。輸入為 x_in = x[0] = [1, 1, 1]。
- Expert 0 內部計算:
- h1 = x\_in @ W1\_0.T = [3, 3]
- h3 = x\_in @ W3\_0.T = [3, 3]
- silu(h1) = [2.8577, 2.8577]
- combined = silu(h1) * h3 = [8.5732, 8.5732]
- output = combined @ W2_0.T = [17.1463, 17.1463, 17.1463]
- 加權並累加:
- weighted_output = output * 0.5 = [8.5732, 8.5732, 8.5732]
- y[0] += weighted_output
- y 變為 [[8.5732, 8.5732, 8.5732], [0, 0, 0]]。
3. for i = 1 (處理 Expert 1):
- torch.where(indices == 1) 沒有找到匹配項,跳過。
4. for i = 2 (處理 Expert 2):
- torch.where(indices == 2) 找到 idx = [0, 1]。該專家需處理兩個 token。
- 處理 Token 0 (x_in = x[0] = [1, 1, 1]):
- h1 = x\_in @ W1\_2.T = [9, 9]
- h3 = x\_in @ W3\_2.T = [9, 9]
- silu(h1) = [8.9989, 8.9989]
- combined = silu(h1) * h3 = [80.9900, 80.9900]
- output0 = combined @ W2_2.T = [485.9400, 485.9400, 485.9400]
- 處理 Token 1 (x_in = x[1] = [2, 2, 2]):
- h1 = x\_in @ W1\_2.T = [18, 18]
- h3 = x\_in @ W3\_2.T = [18, 18]
- silu(h1) = [18.0000, 18.0000]
- combined = silu(h1) * h3 = [324.0000, 324.0000]
- output1 = combined @ W2_2.T = [1944.0000, 1944.0000, 1944.0000]
- 加權並累加:
- weighted0 = output0 * 0.5 = [242.9700, 242.9700, 242.9700]
- weighted1 = output1 * 0.5 = [972.0000, 972.0000, 972.0000]
- y[0] += weighted0: y[0] 變為 [8.5732, 8.5732, 8.5732] + [242.9700, 242.9700, 242.9700] = [251.5432, 251.5432, 251.5432]
- y[1] += weighted1: y[1] 變為 [972.0000, 972.0000, 972.0000]
5. for i = 3 (處理 Expert 3):
- torch.where(indices == 3) 找到 idx = [1]。輸入為 x_in = x[1] = [2, 2, 2]。
- Expert 3 內部計算:
- h1 = x\_in @ W1\_3.T = [24, 24]
- h3 = x\_in @ W3\_3.T = [24, 24]
- silu(h1) = [24.0000, 24.0000]
- combined = silu(h1) * h3 = [576.0000, 576.0000]
- output = combined @ W2_3.T = [4608.0000, 4608.0000, 4608.0000]
- 加權並累加:
- weighted_output = output * 0.5 = [2304.0000, 2304.0000, 2304.0000]
- y[1] += weighted_output: y[1] 變為 [972.0000, 972.0000, 972.0000] + [2304.0000, 2304.0000, 2304.0000] = [3276.0000, 3276.0000, 3276.0000]
最終結果:
演算結束後,輸出張量 y 的值為:
[[251.5432, 251.5432, 251.5432],
通過這個極其詳細的數值演算,可以看到 token 是如何被分發和結果如何被彙總的,並瞭解在每個專家內部,數據是如何經過一系列線性變換和激活函數處理的。
總結
本節詳細介紹了一種基於循環的樸素實現。這種實現方式的優點在於邏輯清晰、直觀易懂。這種實現方式採取的是一種“專家視角”的計算模式,也就是按順序遍歷每一個專家,併為該專家篩選出所有分配給它的 token 並執行計算,最終將結果加權累積到對應的輸出位置。
本節還通過代碼分析和詳盡的實例演算一步步地展示了計算的全過程。
這個樸素的實現雖然不是為高性能而設計的,但它為我們理解更復雜的優化算子,如下一部分將要介紹的 Fused MoE,提供了一個不可或缺的、關於“計算正確性”的基準。我們必須先知道要計算“什麼”,才能更好地探討如何“更快地”計算。
展望
至此,我們通過對一個樸素、基於循環的 MoE 實現方案的分析,可以清晰地理解 MoE 計算的內在邏輯與最終目標。
然而,這個實現在邏輯清晰的同時,卻也暴露了其計算模式是由大量獨立、小規模的計算組成的。而現代 GPU 通常是為了進行大規模並行計算而設計的,其峯值性能只有在處理大型、連續的數據塊時才能被充分激發。這也正是下一部分將要深入探討的 Fused MoE 方案所要解決的核心問題。
參考鏈接:
[1]:https://github.com/deepseek-ai/DeepSeek-V3/blob/9b4e9788e4a3a731f7567338ed15d3ec549ce03b/inference/model.py#L636
相關文章推薦
Triton 實戰:從零開始構建一個 GPU 序列化算子-基礎實現
blue-rdma 設計介紹 (三)—— 數據包處理
虛擬 RDMA 設備驅動實現(一):環境配置與Linux內核模塊初探
達坦科技始終致力於打造高性能AI+Cloud基礎設施平台,積極推動AI應用的落地。達坦科技通過軟硬件深度融合的方式,提供AI推理引擎和高性能網絡,為AI應用提供彈性、便利、經濟的基礎設施服務,以此滿足不同行業客户對AI+Cloud的需求。
公眾號:達坦科技DatenLord
DatenLord官網:
https://datenlord.github.io/zh-cn/
知乎賬號:
https://www.zhihu.com/org/da-tan-ke-ji
B站:
https://space.bilibili.com/2017027518
郵箱:info@datenlord.com
如果您有興趣加入達坦科技Rust前沿技術交流羣、硬件敏捷開發和驗證方法學討論羣或AI Infra交流羣,請添加小助手微信:DatenLord_Tech