咱們今天來聊聊融合Transformer+LSTM+CNN,這也是有一位同學提到的。

核心點:用卷積抓短期、用 LSTM 維護狀態、用自注意力抓任意距離的依賴。

首先,咱們來看看這三位“同學”各自擅長的點在哪裏?

CNN(卷積):擅長抓“局部模式”,像短期的波峯/波谷、週期裏的固定形狀。

LSTM(長短時記憶網絡):擅長記住“時間上的因果和長期依賴”,把過去重要的記憶留着。

Transformer(自注意力):擅長把序列裏任意兩個時刻相互比較、找全局相關性,而且能並行處理。

把三人意見融合起來,就是把“局部細節 + 長期記憶 + 全局上下文”結合,通常能比單一模型更魯棒、更準確,尤其是時間序列既有短期規律又有長短期混合依賴時。



常見的融合方式有幾類(每種在不同場景下有用):

  1. 串聯:例如 CNN → LSTM → Transformer。先提取局部特徵,再用 LSTM 建長期狀態,最後用 Transformer 做全局交互。簡單直觀。
  2. 並行 → 拼接/投影融合:三個模塊並行處理同一輸入,得到三個特徵向量/序列,再把它們拼接或加權合併。適合保留“各自專長”。
  3. 注意力融合 / 交叉注意力:把一個模塊的輸出當作 query,另兩個的輸出當作 key/value,通過注意力機制動態融合。更靈活,也更強。
  4. 門控融合:用可學習的門(sigmoid/softmax)為不同專家分配權重,按權重加權求和。對噪聲魯棒。

下面把這些想法拆開並給出詳細數學表達~

核心原理

CNN(1D 卷積)

對時間維做一維卷積(kernel size = ),輸出序列 。

一個通道的卷積在時刻  的計算(不計 padding/stride 的邊界細節):

然後通常經過非線性激活(ReLU/tanh)與批歸一化/層歸一化:

若想把 CNN 輸出投影到 Transformer 的維度 :

LSTM

標準 LSTM 的逐步公式(每步輸入 ,上一步隱藏  和細胞狀態 ):

(遺忘門)(輸入門)(輸出門)(候選記憶)

其中  是 sigmoid, 表示按位乘。若批量輸入,矩陣維度按批展開。

LSTM 可以返回全部時刻的隱藏序列 ,或只返回最後隱藏 。

把 LSTM 輸出映射到 :

Transformer

位置編碼

把輸入加上位置編碼後進入編碼器。

Scaled Dot-Product Attention(單頭):

給定 :

Multi-head Attention

然後是殘差連接 + 層歸一化 + 前饋網絡(FFN):

融合的具體數學套路

A. 串聯(Serial)

示例:CNN → LSTM → Transformer

  1. CNN 提取局部特徵: (詳見第2節)。
  2. LSTM 消化局部序列: 。得到序列 。
  3. 投影並加位置編碼送入 Transformer: 。 Transformer 編碼得到 。
  4. 輸出層預測(例:直接映射到未來 H 步):

其中  可為取最後時刻、平均池化或一個小的解碼器。

B. 並行 + 拼接

三路並行得到 (都投影到相同維度 ):

再線性投影:

C. 門控融合

先對三個表示做時間/通道降維(例如取時間平均或用 Dense),得到向量 :

融合:

或者做逐元素門控(更細粒度): 然後  等。

D. 交叉注意力

把 Transformer 輸出當 query,把 CNN/LSTM 的輸出當 key/value:

這讓 Transformer 動態“向”CNN/LSTM 查詢重要的短期或長期信息。

輸出預測頭與損失函數

最簡單的輸出層:線性映射把融合向量投到未來  步:

常用的訓練損失(均方誤差 MSE):

其它:MAE、MAPE、分位數損失(若要不對稱錯誤度量)等。

以上所述,把 CNN、LSTM、Transformer 三者融合,等於把“局部模式識別 + 時序記憶 + 全局依賴建模”這三種能力組合在一起。

用卷積抓短期、用 LSTM 維護狀態、用自注意力抓任意距離的依賴,融合策略(串聯/並行/注意力/門控)決定信息如何匯聚。數學上就是把各自的輸出通過投影/注意力/加權等可學習操作合併,最後用線性層預測未來並以 MSE/MAE 等損失訓練。

完整案例

這裏,咱們分享一個案例,融合 CNN + LSTM + Transformer 進行時間序列預測,包括虛擬數據集生成、模型構建、訓練、預測等等。

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns


# 1. 時間序列數據
np.random.seed(42)
T_total = 1000# 總時間步
t = np.arange(T_total)
# 模擬一個混合信號:趨勢 + 週期 + 噪聲
series = 0.05*t + 2*np.sin(0.2*t) + 0.5*np.sin(0.05*t) + np.random.normal(0, 0.3, T_total)
series = series.astype(np.float32)

# 2. 構建 Dataset
class TimeSeriesDataset(Dataset):
    def __init__(self, data, input_len=30, pred_len=5):
        self.data = data
        self.input_len = input_len
        self.pred_len = pred_len
        self.len = len(data) - input_len - pred_len + 1
        
    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        x = self.data[idx:idx+self.input_len]
        y = self.data[idx+self.input_len:idx+self.input_len+self.pred_len]
        return torch.from_numpy(x).unsqueeze(-1), torch.from_numpy(y).unsqueeze(-1)

input_len = 30
pred_len = 5
dataset = TimeSeriesDataset(series, input_len, pred_len)

# 拆分訓練/測試集
train_size = int(len(dataset)*0.7)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)  # 時間序列通常不要shuffle
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 3. 定義融合模型:CNN + LSTM + Transformer
class CNN_LSTM_Transformer(nn.Module):
    def __init__(self, input_dim=1, cnn_channels=16, lstm_hidden=32, transformer_dim=32,
                 transformer_heads=4, transformer_layers=1, pred_len=5):
        super().__init__()
        # CNN
        self.cnn = nn.Conv1d(in_channels=input_dim, out_channels=cnn_channels, kernel_size=3, padding=1)
        self.cnn_relu = nn.ReLU()
        
        # LSTM
        self.lstm = nn.LSTM(input_size=cnn_channels, hidden_size=lstm_hidden, batch_first=True)
        
        # Transformer Encoder 
        encoder_layer = nn.TransformerEncoderLayer(d_model=transformer_dim, nhead=transformer_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
        
        # Projection layers
        self.proj_lstm = nn.Linear(lstm_hidden, transformer_dim)
        self.pred_len = pred_len
        self.fc_out = nn.Linear(transformer_dim, pred_len)
    
    def forward(self, x):
        # x: [batch, seq_len, 1]
        batch_size, seq_len, _ = x.shape
        # CNN expects [batch, channels, seq_len]
        cnn_out = self.cnn_relu(self.cnn(x.transpose(1,2)))  # [B, C, T]
        cnn_out = cnn_out.transpose(1,2)  # [B, T, C]
        # LSTM
        lstm_out, _ = self.lstm(cnn_out)  # [B, T, hidden]
        lstm_proj = self.proj_lstm(lstm_out)  # [B, T, transformer_dim]
        # Transformer
        trans_out = self.transformer(lstm_proj)  # [B, T, transformer_dim]
        # 取最後時間步輸出預測
        out = self.fc_out(trans_out[:, -1, :])  # [B, pred_len]
        return out.unsqueeze(-1)  # [B, pred_len, 1]

# 4. 訓練模型
device = torch.device('cuda'if torch.cuda.is_available() else'cpu')
model = CNN_LSTM_Transformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

epochs = 50
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {total_loss/train_size:.4f}")

# 5. 測試集預測
model.eval()
preds, trues = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        pred = model(xb)
        preds.append(pred.cpu().numpy())
        trues.append(yb.numpy())
preds = np.concatenate(preds, axis=0).squeeze(-1)  # [num_samples, pred_len]
trues = np.concatenate(trues, axis=0).squeeze(-1)

# 6. 可視化分析
time_axis = np.arange(len(series))

# 圖1:原始時間序列
plt.figure(figsize=(12,4))
plt.plot(time_axis, series, color='royalblue')
plt.title("圖1:原始時間序列(趨勢 + 週期 + 噪聲)")
plt.xlabel("Time")
plt.ylabel("Value")
plt.show()

# 圖2:訓練/測試集分佈
plt.figure(figsize=(12,4))
plt.plot(time_axis[:len(series[:train_size+input_len+pred_len])], series[:train_size+input_len+pred_len], color='green', label='Train')
plt.plot(time_axis[len(series[:train_size+input_len+pred_len]):], series[len(series[:train_size+input_len+pred_len]):], color='red', label='Test')
plt.title("圖2:訓練集與測試集時間序列分佈")
plt.xlabel("Time")
plt.ylabel("Value")
plt.legend()
plt.show()

# 圖3:預測對比(測試集前50個樣本)
plt.figure(figsize=(12,4))
plt.plot(trues[:50].flatten(), color='black', label='True')
plt.plot(preds[:50].flatten(), color='orange', linestyle='--', label='Predicted')
plt.title("圖3:測試集預測對比(前50步)")
plt.xlabel("Sample Index")
plt.ylabel("Value")
plt.legend()
plt.show()

# 圖4:預測殘差分佈
residuals = trues - preds
plt.figure(figsize=(12,4))
sns.histplot(residuals.flatten(), bins=30, kde=True, color='purple')
plt.title("圖4:預測殘差分佈(True - Pred)")
plt.xlabel("Residual")
plt.ylabel("Frequency")
plt.show()

圖1(原始時間序列):顯示整個時間序列的趨勢、週期和噪聲,給出數據特徵概覽。

融合Transformer+LSTM+CNN,時間序列預測 !!_卷積

圖2(訓練/測試集分佈):展示訓練集與測試集的時間位置,保證嚴格的時間順序,避免數據泄露。

融合Transformer+LSTM+CNN,時間序列預測 !!_時間序列_02

圖3(預測對比):展示模型預測值與真實值對比,驗證模型捕捉趨勢與週期的能力。

融合Transformer+LSTM+CNN,時間序列預測 !!_卷積_03

圖4(預測殘差分佈):分析預測誤差分佈,用於觀察模型偏差或異常預測點。

融合Transformer+LSTM+CNN,時間序列預測 !!_時間序列_04

整個代碼中,數據嚴格切分,滑動窗口只用過去預測未來。

CNN + LSTM + Transformer 融合實現局部 + 長期 + 全局依賴建模。

可視化多維度分析模型表現