1. 為什麼你的模型“記性”這麼差?(痛點與背景)

想象一下,你訓練了一個神經網絡來識別手寫數字(MNIST),準確率高達 99%。

接着,你希望能複用這個聰明的腦子,讓它繼續學習識別時尚單品(Fashion-MNIST)。

你把模型拿來,在“衣服鞋子”的數據集上跑了幾輪訓練。結果很棒,它現在能完美識別運動鞋和襯衫了。

但是,當你隨手扔給它一張數字 “7” 的圖片時,它卻一臉自信地告訴你:“這是一隻靴子!”

這就是災難性遺忘(Catastrophic Forgetting)

在傳統的反向傳播中,為了讓模型適應新任務(Task B),優化器會毫不留情地修改模型裏的權重參數。它並不在乎這些參數之前對舊任務(Task A)有多重要,只要能降低 Task B 的 Loss,它就會大幅改變權重。結果就是:舊知識的“神經連接”被徹底破壞了。

EWC(Elastic Weight Consolidation) 的出現,就是為了解決這個問題。它能讓模型在學習新技能的同時,優雅地“鎖住”那些對舊技能至關重要的記憶。


2. 概念拆解:給神經元裝上“彈簧”

EWC 的論文裏充滿了費雪信息矩陣(Fisher Information Matrix)和黑森矩陣(Hessian Matrix)等高深術語,但我們先忘掉數學,用**“房間裝修”**來打個比方。

🏠 生活化類比:設計師的妥協

把神經網絡想象成一個剛剛裝修好的房間

  • Task A(舊任務):這是一個“家庭影院”模式。為了達到最佳視聽效果,沙發(權重 1)、音響(權重 2)、投影儀(權重 3)必須擺在特定的位置。
  • Task B(新任務):現在你想把這個房間改成“瑜伽室”。

沒有 EWC 的做法:

裝修隊進來,為了騰出瑜伽空間,直接把沙發扔出去,把音響砸了。瑜伽室很完美,但家庭影院徹底毀了。

有 EWC 的做法:

你告訴裝修隊:“有些東西你們隨便動,但有些東西很重要,動起來很費勁。”

  • 不重要的權重(比如牆角的綠植):對家庭影院影響不大,隨便移。
  • 重要的權重(比如投影儀):對家庭影院極其重要。如果你非要移動它,就像是在拉一根極其堅硬的彈簧。你可以稍微挪一點點,但挪得越遠,彈簧的反作用力(懲罰項)就越大。

EWC 的核心魔法就在於:它能自動計算出哪些傢俱(權重)是“承重牆”,哪些是“裝飾品”。

🧩 原理圖解邏輯

  1. 訓練 Task A:正常訓練,找到最優權重
  2. 搭建合約量化機器人和現貨量化機器人現在這麼火爆? - osc_權重

  3. 計算重要性(Fisher Matrix):分析 Task A 的 Loss 地形。如果某個權重稍微變動一下,Loss 就劇烈飆升,説明這個權重非常重要(地形陡峭);如果權重變了很多 Loss 還沒啥反應,説明它不重要(地形平坦)。
  4. 訓練 Task B:在 Loss 函數後面加上一個 EWC 懲罰項(那個彈簧)。

搭建合約量化機器人和現貨量化機器人現在這麼火爆? - osc_權重_02

  • :新任務的正常 Loss。
  • :這一項決定了你有多想“守舊”。值越大,越難忘。
  • 費雪信息量(重要性係數)。越重要,這一項越大,改變參數帶來的 Penalty 就越大。

3. 動手實戰:PyTorch 實現 EWC

我們將通過一個極簡的例子:先讓模型擬合一個函數,再擬合另一個函數,看它能不能同時記住兩者。

環境準備

你需要安裝 PyTorch:

pip install torch matplotlib

核心代碼解析

Python

 

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import copy

# ===========================
# 1. 定義一個簡單的神經網絡
# ===========================
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # 為了演示,我們用一個小型的網絡
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 2) # 輸出2類

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# ===========================
# 2. EWC 核心類 (The Magic)
# ===========================
class EWC:
    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        # 存儲舊任務的最優參數 (theta_A*)
        self.params = {n: p.data.clone() for n, p in self.model.named_parameters()}
        # 存儲每個參數的重要性 (Fisher Information)
        self.fisher = self._calculate_fisher()

    def _calculate_fisher(self):
        fisher = {}
        # 初始化 Fisher 矩陣為 0
        for n, p in self.model.named_parameters():
            fisher[n] = torch.zeros_like(p.data)

        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        
        # 遍歷數據計算梯度的平方
        # 這裏的邏輯是:梯度越大,説明參數稍微一動 Loss 變化就大 -> 參數越重要
        for input_data, target in self.dataset:
            self.model.zero_grad()
            output = self.model(input_data.unsqueeze(0)) # batch_size=1
            loss = criterion(output, target.unsqueeze(0))
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    # Fisher 近似等於 梯度的平方
                    fisher[n] += p.grad.data ** 2
        
        # 歸一化
        for n in fisher:
            fisher[n] /= len(self.dataset)
            
        return fisher

    def penalty(self, new_model):
        loss = 0
        for n, p in new_model.named_parameters():
            # EWC 公式: Sum( F * (new_theta - old_theta)^2 )
            _loss = self.fisher[n] * (p - self.params[n]) ** 2
            loss += _loss.sum()
        return loss

# ===========================
# 3. 模擬訓練流程
# ===========================
def get_data(task_id):
    # 模擬數據:Task 1 輸入全1,Task 2 輸入全0
    if task_id == 1:
        return [(torch.ones(10), torch.tensor(0)) for _ in range(100)]
    else:
        return [(torch.zeros(10), torch.tensor(1)) for _ in range(100)]

# 實例化模型
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

print(">>> 開始訓練 任務 A (識別全1向量)")
data_a = get_data(1)
for epoch in range(5):
    for x, y in data_a:
        optimizer.zero_grad()
        loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        optimizer.step()

print("任務 A 訓練完成。保存 EWC 狀態...")
# --- 關鍵步驟:計算 Task A 的重要性權重 ---
ewc = EWC(model, data_a) 

print(">>> 開始訓練 任務 B (識別全0向量),同時開啓 EWC 保護")
data_b = get_data(2)
ewc_lambda = 1000  # 懲罰力度,越大越照顧舊任務

for epoch in range(5):
    total_loss = 0
    for x, y in data_b:
        optimizer.zero_grad()
        
        # 1. 計算新任務的 Loss
        loss_b = criterion(model(x.unsqueeze(0)), y.unsqueeze(0))
        
        # 2. 計算 EWC 懲罰項 (舊任務的記憶)
        loss_ewc = ewc.penalty(model)
        
        # 3. 總 Loss
        loss = loss_b + (ewc_lambda * loss_ewc)
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}: Loss = {total_loss:.4f}")

# ===========================
# 4. 驗證結果
# ===========================
model.eval()
test_a = model(torch.ones(10).unsqueeze(0))
test_b = model(torch.zeros(10).unsqueeze(0))

print("\n=== 最終測試 ===")
print(f"Task A (應為類別0): 預測概率 {torch.softmax(test_a, dim=1).detach().numpy()}")
print(f"Task B (應為類別1): 預測概率 {torch.softmax(test_b, dim=1).detach().numpy()}")

代碼劃重點:

  1. _calculate_fisher:這是 EWC 的靈魂。我們在訓練完 Task A 後,立刻凍結模型,通過反向傳播拿到梯度。注意這裏不用 optimizer.step(),我們只想要梯度值來計算 $F_i$。
  2. penalty:在訓練 Task B 時,每次迭代都會調用這個函數。它檢查當前的參數偏離舊參數有多遠,並乘以重要性係數。

4. 進階深潛:陷阱與最佳實踐

⚠️ 常見陷阱

  1. Fisher 計算開銷:在上面的代碼中,我們是一個樣本一個樣本算的(為了代碼清晰)。在生產環境中,這會非常慢。
  • 優化:使用小批量(Mini-batch)來估算 Fisher 信息,不需要遍歷整個數據集,隨機採樣幾千個樣本通常就足夠了。
  1. Lambda 的平衡
  • 太小:EWC 失效,照樣遺忘。
  • 太大:模型被舊記憶“鎖死”了,根本學不進新東西(欠擬合 Task B)。這需要像調學習率一樣去調參。

🚀 生產環境貼士

  • 多任務擴展:如果你有 Task A, B, C... 怎麼做?
  • 通常的做法是維護一個累積的 Fisher 矩陣。當你學完 Task B 準備學 C 時,你的錨點應該變成 Task B 的參數,而 Fisher 矩陣應該是 A 和 B 的重要性之和。
  • 在線 EWC (Online EWC):這是一種更高效的變體,解決了存儲多個 Fisher 矩陣帶來的內存爆炸問題。

5. 總結與延伸

核心知識點

EWC 本質上是一種正則化(Regularization)技術。它通過費雪信息矩陣識別出神經網絡中的“關鍵承重牆”,並在學習新知識時強行保護這些區域,從而在可塑性(學習新知識)穩定性(保持舊知識)之間找到平衡。