pytorch

pytorch全连接层权重怎么获取

小樊
81
2024-12-25 19:47:24
栏目: 深度学习

在PyTorch中,要获取全连接层的权重,首先需要访问模型的每一层,然后通过层的权重属性来获取权重。以下是一个简单的示例,展示了如何获取一个具有单个全连接层的简单神经网络的权重:

import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(784, 10)  # 假设输入特征数为784,输出类别数为10

    def forward(self, x):
        return self.fc(x)

# 创建网络实例
net = SimpleNet()

# 获取全连接层的权重
weights = net.fc.weight.data

# 打印权重
print("Weights:", weights)

在这个例子中,我们定义了一个名为SimpleNet的简单神经网络,其中包含一个全连接层fc。我们创建了一个网络实例net,然后通过访问net.fc.weight.data来获取全连接层的权重。最后,我们打印了这些权重。

0
看了该问题的人还看了