Every Token Counts: Generalizing 16M Ultra-Long Context in Large Language Models(超長文本模型論文HSA)
這篇論文介紹了 HSA-UltraLong,這是一個基於 分層稀疏注意力(Hierarchical Sparse Attention, HSA) 機制的模型。該模型能夠在僅使用 32K 長度進行訓練的情況下,成功將上下文窗口外推到 1600萬(16M)Token,並在大海撈針(NIAH)測試中保持極高的準確率 。
1. 核心內容總結
論文旨在解決 LLM 如何實現“無限”記憶的問題。作者認為,要處理超長上下文,必須滿足三個條件:稀疏性(不能全關注)、隨機訪問靈活性(能精準檢索過去的信息)和長度外推性(從短訓練推廣到長推理) 。
現有的方法(如 Mamba、線性注意力、NSA)在檢索精度或長距離外推上存在短板。HSA-UltraLong 通過結合滑動窗口注意力(SWA)和HSA,模仿了“混合專家模型(MoE)”的思路,實現了高效且精準的超長文本檢索與生成 。
2. 核心創新點
A. 類似 MoE 的分層稀疏注意力 (HSA)
這是論文最大的創新。傳統的稀疏注意力(如 NSA)通常是“先選塊,再拼接,最後算注意力”,導致選擇過程不可導(無法端到端學習)。
HSA 的做法類似於 Mixture-of-Experts (MoE):
- 分塊與路標(Landmark): 將歷史 KV Cache 分成固定長度的塊(Chunk),每個塊有一個路標向量。
- 可學習的檢索: 當前 Token 與所有歷史塊的路標計算相似度,選出 Top-K 個最相關的塊。
- 獨立注意力與融合: 對選出的 K 個塊分別計算注意力,最後根據檢索分數(Retrieval Scores)對注意力結果進行加權求和。
- 優勢: 檢索分數直接參與最終計算,因此可以通過反向傳播優化檢索能力,讓模型學會“該去哪裏找記憶”。
B. 混合位置編碼策略 (RoPE + NoPE)
- 局部滑動窗口(SWA): 使用 RoPE(旋轉位置編碼),處理短期依賴。
- 全局 HSA: 不使用位置編碼(NoPE)。
- 原因: 作者發現 RoPE 會阻礙長度外推,而不使用位置編碼有助於模型將檢索能力從短序列泛化到無限長序列。
C. “蹺蹺板”訓練效應與解決方案
作者發現滑動窗口(SWA)和 HSA 之間存在競爭關係。如果 SWA 窗口過大(如 4K),模型會偷懶,只關注局部信息,導致 HSA 學不到長距離檢索能力。
解決方案: 在預訓練/熱身階段,故意將 SWA 窗口縮小(如 512),強制模型依賴 HSA 去獲取信息,練好“內功”後再擴大窗口。
3. Python Demo 代碼助解
這個 Demo 展示了 HSA 如何像 MoE 一樣運作:計算路由分數 -> 選中 TopK 塊 -> 分別計算 Attention -> 加權融合。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimplifiedHSA(nn.Module):
def __init__(self, d_model, chunk_size, top_k):
super().__init__()
self.d_model = d_model
self.chunk_size = chunk_size
self.top_k = top_k
# 簡單的線性層生成 Query, Key, Value
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
# 專門用於生成“塊路標(Landmark)”的投影層
# 用於計算檢索分數
self.w_landmark = nn.Linear(d_model, d_model)
self.w_retrieval_q = nn.Linear(d_model, d_model)
def forward(self, x, past_kv_raw):
"""
x: 當前輸入的 token [Batch, 1, Dim]
past_kv_raw: 歷史所有的 token 序列 [Batch, Total_Len, Dim]
"""
B, _, D = x.shape
Total_Len = past_kv_raw.shape[1]
# 1. 預處理:將歷史信息分塊 (Chunking)
# 假設 Total_Len 能被 chunk_size 整除
num_chunks = Total_Len // self.chunk_size
# 形狀變為: [Batch, Num_Chunks, Chunk_Size, Dim]
past_chunks = past_kv_raw.view(B, num_chunks, self.chunk_size, D)
# 2. 生成路標 (Landmark Generation)
# 論文中提到路標是對塊內容的摘要,這裏簡化為對塊內token取平均
# Landmark 形狀: [Batch, Num_Chunks, Dim]
chunk_landmarks = self.w_landmark(past_chunks.mean(dim=2))
# 3. 檢索 (Retrieval / Routing)
# 計算當前 token 與每個塊路標的相似度
q_retrieval = self.w_retrieval_q(x) # [Batch, 1, Dim]
# 檢索分數: [Batch, 1, Num_Chunks]
retrieval_scores = torch.matmul(q_retrieval, chunk_landmarks.transpose(1, 2)) / (D ** 0.5)
# 選出分數最高的 Top-K 個塊
topk_scores, topk_indices = torch.topk(retrieval_scores, k=self.top_k, dim=-1)
# 對 Top-K 分數進行 Softmax,作為融合的權重
# 對應論文公式 (2) 中的 w_{t,i}
routing_weights = F.softmax(topk_scores, dim=-1) # [Batch, 1, K]
# 4. 稀疏注意力 (Attention on Selected Chunks)
# 這裏的關鍵是:只對選中的 K 個塊計算 Attention
q_attn = self.w_q(x) # [Batch, 1, Dim]
# 收集選中的塊對應的 KV
# selected_chunks 形狀: [Batch, K, Chunk_Size, Dim]
# 實際代碼需要複雜的 gather 操作,這裏為了演示邏輯簡化處理
selected_chunks = self._gather_chunks(past_chunks, topk_indices)
k_attn = self.w_k(selected_chunks)
v_attn = self.w_v(selected_chunks)
# 在每個塊內部獨立計算 Attention
# Q: [B, 1, D], K_chunk: [B, Chunk_Size, D] -> Attn: [B, 1, Chunk_Size]
# 結果 outputs 形狀: [Batch, K, Dim] (每個塊貢獻一個輸出向量)
chunk_outputs = []
for k in range(self.top_k):
# 取出第 k 個選中的塊
k_c = k_attn[:, k, :, :]
v_c = v_attn[:, k, :, :]
# 標準 Attention 計算
attn_score = torch.matmul(q_attn, k_c.transpose(1, 2)) / (D ** 0.5)
attn_prob = F.softmax(attn_score, dim=-1)
out_c = torch.matmul(attn_prob, v_c) # [Batch, 1, Dim]
chunk_outputs.append(out_c)
chunk_outputs = torch.cat(chunk_outputs, dim=1) # [Batch, K, Dim]
# 5. 加權融合 (Weighted Sum Fusion)
# 利用檢索分數作為權重,融合各個塊的 Attention 結果
# weights: [Batch, 1, K], outputs: [Batch, K, Dim]
final_output = torch.matmul(routing_weights, chunk_outputs) # [Batch, 1, Dim]
return final_output
def _gather_chunks(self, all_chunks, indices):
# 輔助函數:根據索引從所有塊中提取 Top-K 個塊
B, K = indices.shape[0], indices.shape[2]
C, D = self.chunk_size, self.d_model
# 擴展 indices 以便 gather: [B, K, C, D]
# 實際工程實現通常會優化這一步以避免內存複製
gathered = torch.zeros(B, K, C, D, device=all_chunks.device)
for b in range(B):
idx = indices[b, 0, :] # [K]
gathered[b] = all_chunks[b, idx]
return gathered
# --- 模擬運行 ---
d_model = 64
chunk_size = 4 # 假設每塊4個token
top_k = 2 # 每次只看最相關的2個塊
seq_len = 20 # 歷史總長度
model = SimplifiedHSA(d_model, chunk_size, top_k)
current_token = torch.randn(1, 1, d_model)
history_tokens = torch.randn(1, seq_len, d_model)
output = model(current_token, history_tokens)
print(f"輸入歷史長度: {seq_len}, 分塊數量: {seq_len//chunk_size}")
print(f"HSA 選擇了最相關的 {top_k} 個塊進行注意力計算")
print(f"最終輸出形狀: {output.shape}")
代碼解析:
- 分塊(Chunking): 代碼將歷史 Token 強行切分為
past_chunks,模擬了論文將長序列結構化的過程。 - 路標(Landmarks):
chunk_landmarks代表了每個塊的“摘要”,用於快速檢索,避免了與所有 Token 進行點積。 - 可導的路由(Weighted Sum): 請注意最後一步
torch.matmul(routing_weights, chunk_outputs)。由於routing_weights來自retrieval_scores的 Softmax,這意味着如果模型覺得某個塊找得不對,梯度可以通過這個權重回傳,更新w_retrieval_q和w_landmark。這就是 HSA 優於 NSA(不可導選擇)的關鍵所在 。
4.分塊方式(Chunking)和路標(Landmark)生成的具體細節:
1. 分塊方式(Chunking Strategy)
不是按照 4K 分塊,而是按照 64 個 Token 分塊。
分塊大小(Chunk Size): 論文中明確指出,文本序列被劃分為固定長度的塊,默認長度為64 。
為什麼是 64? 選擇這個尺寸主要是為了更好地與硬件(如 GPU 的計算單元)對齊,以提高計算效率 。
2. 路標(Landmark)是如何生成的?
路標不僅僅是簡單的平均值,而是通過一個專門的雙向編碼器(Bi-directional Encoder) 生成的語義摘要。具體步驟如下:
- 提取中間層狀態: 模型首先從中間層(第
層) 獲取該塊內 64 個 Token 的隱藏狀態(Hidden States) 。
- 添加 [CLS] Token: 在這 64 個 Token 的狀態序列中加入一個特殊的 [CLS] Token 。
- 雙向編碼: 將這個序列(64個 Token + [CLS])輸入到一個輕量級的雙向編碼器(Bi-directional Chunk Encoder) 中。注意,這裏使用雙向注意力是為了讓路標能捕捉到塊內上下文的完整信息,而不受因果(Causal)掩碼的限制 。
- 提取摘要: 編碼器輸出中的特定部分(對應 [CLS] 的輸出 )即被作為該塊的路標(Landmark) 。
3. 路標向量的維度與相似度計算
路標(Landmark)的作用完全等同於 Key 向量,用於通過點積計算相似度(即檢索分數)。
相似度計算: 當前的 Token 會經過一個線性變換生成檢索查詢向量 ,然後與歷史塊的路標 (即上一步生成的 )進行點積(Dot Product),從而得到檢索分數(Retrieval Score) 。
- 維度: 路標向量的維度是 **(模型的隱藏層維度,Hidden Dimension),例如 4096 或 5120 。
- 是否一致: 它與模型的主隱藏層維度一致,但通常大於標準的單頭注意力 Key 向量的維度(單頭維度通常是 )。你可以將其理解為一個全局的、未切分多頭的超級 Key,專門用於粗粒度的塊級檢索。
總結流程圖解:
原始文本 -> 每 64 Token 切一塊 -> 取中間層向量 -> 加 [CLS] 進雙向編碼器 -> 產出維度為 **** 的路標向量 -> 與當前 Token 的 Query 計算點積 -> Top-K 檢索。
所以,在一個包含 HSA 的層中,實際參與注意力計算的 Token 數量(Active Context)是這樣組成的:
4. 計算公式
- SWA (局部注意力): 負責覆蓋最近的 4096 個 Token,捕捉短期依賴和語法結構 。
- HSA (全局稀疏注意力): 負責從千萬級歷史中檢索出 64 個最相關的塊,每塊 64 個 Token,共計 4096 個 Token 。
結論: 在一個 HSA 層中,無論歷史上下文是 100萬 還是 1600萬,模型實際上只需要對 8192 個 Token(4K 局部 + 4K 全局)進行注意力計算。