PyTorch 的 torch.nn 模塊是構建和訓練神經網絡的核心模塊,它提供了豐富的類和函數來定義和操作神經網絡。
以下是 torch.nn 模塊的一些關鍵組成部分及其功能:
- nn.Module 類
nn.Module 是所有自定義神經網絡模型的基類。用户通常會從這個類派生自己的模型類,並在其中定義網絡層結構以及前向傳播函數(forward pass)。 - 預定義層(Modules)
包括各種類型的層組件,例如卷積層(nn.Conv1d, nn.Conv2d, nn.Conv3d)、全連接層(nn.Linear)、激活函數(nn.ReLU, nn.Sigmoid, nn.Tanh)等。 - 容器類
nn.Sequential:允許將多個層按順序組合起來,形成簡單的線性堆疊網絡。
nn.ModuleList 和 nn.ModuleDict:可以動態地存儲和訪問子模塊,支持可變長度或命名的模塊集合。 - 損失函數
torch.nn 包含了一系列用於衡量模型預測與真實標籤之間差異的損失函數,例如均方誤差損失(nn.MSELoss)、交叉熵損失(nn.CrossEntropyLoss)等。 - 實用函數接口
nn.functional(通常簡寫為 F),包含了許多可以直接作用於張量上的函數,它們實現了與層對象相同的功能,但不具有參數保存和更新的能力。例如,可以使用 F.relu() 直接進行 ReLU 操作,或者 F.conv2d() 進行卷積操作。 - 初始化方法:
torch.nn.init提供了一些常用的權重初始化策略,比如 Xavier 初始化 (nn.init.xavier_uniform_()) 和 Kaiming 初始化 (nn.init.kaiming_uniform_()),這些對於成功訓練神經網絡至關重要。
1. torch.nn 模塊參考手冊
1.1 神經網絡容器
1.2 線性層
1.3 卷積層
1.4 池化層
1.5 激活函數
1.6 損失函數
1.7 歸一化層
1.8 循環神經網絡層
1.9 嵌入層
1.10 Dropout 層
1.11 實用函數
import torch
import torch.nn as nn
# 定義一個簡單的神經網絡
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(20, 1)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 創建模型和輸入
model = SimpleNet()
input = torch.randn(5, 10)
output = model(input)
print(output)
本文章為轉載內容,我們尊重原作者對文章享有的著作權。如有內容錯誤或侵權問題,歡迎原作者聯繫我們進行內容更正或刪除文章。