在Torch中定义一个神经网络模型通常需要使用nn.Module类。下面是一个示例代码,展示了如何定义一个简单的全连接神经网络模型:
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = SimpleNN()
在上面的代码中,我们定义了一个名为SimpleNN的神经网络模型,它包含两个全连接层和一个ReLU激活函数。在__init__
方法中,我们定义了模型的各个层,然后在forward
方法中定义了数据在模型中的流动路径。
需要注意的是,在定义神经网络模型时,通常需要继承nn.Module类,并实现__init__
和forward
方法。__init__
方法用于初始化模型的结构,forward
方法用于定义数据在模型中的传播路径。