在PyTorch中,可以使用torch.nn.Transformer
类来调用Transformer模型。以下是一个使用Transformer模型的示例代码:
import torch
import torch.nn as nn
# 定义Transformer模型
class TransformerModel(nn.Module):
def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
super(TransformerModel, self).__init__()
self.transformer = nn.Transformer(
d_model=input_dim,
nhead=num_heads,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers,
dim_feedforward=hidden_dim
)
def forward(self, src, tgt):
output = self.transformer(src, tgt)
return output
# 创建Transformer模型实例
input_dim = 512
hidden_dim = 2048
num_layers = 6
num_heads = 8
model = TransformerModel(input_dim, hidden_dim, num_layers, num_heads)
# 准备输入数据
batch_size = 16
src_seq_len = 10
tgt_seq_len = 5
src = torch.randn(batch_size, src_seq_len, input_dim)
tgt = torch.randn(batch_size, tgt_seq_len, input_dim)
# 前向传播
output = model(src, tgt)
在这个示例中,我们首先定义了一个继承自nn.Module
的自定义Transformer模型类TransformerModel
。在__init__
方法中,我们使用nn.Transformer
类来创建一个Transformer模型,并指定输入维度、隐藏层维度、编码器和解码器的层数,以及注意力头数。在forward
方法中,我们将输入数据传入Transformer模型进行前向传播,并返回输出。
然后,我们创建了一个Transformer模型实例,并准备了输入数据。最后,我们通过调用模型的forward
方法来进行前向传播,并得到输出结果。