在PyTorch中,nn.Linear
是一个用于定义线性变换的类。可以使用它来定义一个全连接层(也称为线性层)。
以下是如何使用nn.Linear
的例子:
首先,导入需要的模块:
import torch
import torch.nn as nn
接下来,定义一个包含输入和输出大小的线性层:
input_size = 10
output_size = 5
linear_layer = nn.Linear(input_size, output_size)
这将创建一个线性层,将输入维度为10的特征映射到输出维度为5的特征。
然后,可以将数据传递给线性层进行转换:
input_data = torch.randn(1, input_size)
output_data = linear_layer(input_data)
这将根据线性层的权重和偏差将输入数据进行线性变换,并返回输出数据。
最后,可以查看线性层的权重和偏差:
print(linear_layer.weight)
print(linear_layer.bias)
这将打印出线性层的权重矩阵和偏差向量。
注意:nn.Linear
类还可以接受一些其他参数,例如是否添加偏差(默认为True)、权重初始化方法等。你可以查阅PyTorch的官方文档以获取更多详细信息。