PyTorch 的 torch.nn 模塊是構建和訓練神經網絡的核心模塊,它提供了豐富的類和函數來定義和操作神經網絡。

以下是 torch.nn 模塊的一些關鍵組成部分及其功能:

  1. nn.Module 類
    nn.Module 是所有自定義神經網絡模型的基類。用户通常會從這個類派生自己的模型類,並在其中定義網絡層結構以及前向傳播函數(forward pass)。
  2. 預定義層(Modules)
    包括各種類型的層組件,例如卷積層(nn.Conv1d, nn.Conv2d, nn.Conv3d)、全連接層(nn.Linear)、激活函數(nn.ReLU, nn.Sigmoid, nn.Tanh)等。
  3. 容器類
    nn.Sequential:允許將多個層按順序組合起來,形成簡單的線性堆疊網絡。
    nn.ModuleList 和 nn.ModuleDict:可以動態地存儲和訪問子模塊,支持可變長度或命名的模塊集合。
  4. 損失函數
    torch.nn 包含了一系列用於衡量模型預測與真實標籤之間差異的損失函數,例如均方誤差損失(nn.MSELoss)、交叉熵損失(nn.CrossEntropyLoss)等。
  5. 實用函數接口
    nn.functional(通常簡寫為 F),包含了許多可以直接作用於張量上的函數,它們實現了與層對象相同的功能,但不具有參數保存和更新的能力。例如,可以使用 F.relu() 直接進行 ReLU 操作,或者 F.conv2d() 進行卷積操作。
  6. 初始化方法:
    torch.nn.init 提供了一些常用的權重初始化策略,比如 Xavier 初始化 (nn.init.xavier_uniform_()) 和 Kaiming 初始化 (nn.init.kaiming_uniform_()),這些對於成功訓練神經網絡至關重要。

1. torch.nn 模塊參考手冊

1.1 神經網絡容器

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#pytorch

1.2 線性層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#人工智能_02

1.3 卷積層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#人工智能_03

1.4 池化層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#pytorch_04

1.5 激活函數

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#pytorch_05

1.6 損失函數

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#人工智能_06

1.7 歸一化層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_#pytorch_07

1.8 循環神經網絡層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_神經網絡_08

1.9 嵌入層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_卷積_09

1.10 Dropout 層

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_神經網絡_10

1.11 實用函數

PyTorch中文教程 | (15) 在深度學習和NLP中使用PyTorch_pytorch 中文nlp_卷積_11

import torch
import torch.nn as nn

# 定義一個簡單的神經網絡
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 創建模型和輸入
model = SimpleNet()
input = torch.randn(5, 10)
output = model(input)
print(output)