PyTorch中LSTM的输入和输出实例分析

发布时间:2022-07-27 09:31:54 作者:iii
来源:亿速云 阅读:338

这篇文章主要介绍“PyTorch中LSTM的输入和输出实例分析”,在日常操作中,相信很多人在PyTorch中LSTM的输入和输出实例分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”PyTorch中LSTM的输入和输出实例分析”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

LSTM参数

官方文档给出的解释为:

PyTorch中LSTM的输入和输出实例分析

总共有七个参数,其中只有前三个是必须的。由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要。LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释。

Inputs

关于LSTM的输入,官方文档给出的定义为:

PyTorch中LSTM的输入和输出实例分析

可以看到,输入由两部分组成:input、(初始的隐状态h_0,初始的单元状态c_0)

其中input:

input(seq_len, batch_size, input_size)

(h_0, c_0):

h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)

h_0和c_0的shape一致。

 Outputs

关于LSTM的输出,官方文档给出的定义为:

PyTorch中LSTM的输入和输出实例分析

可以看到,输出也由两部分组成:otput、(隐状态h_n,单元状态c_n)

其中output的shape为:

output(seq_len, batch_size, num_directions * hidden_size)

h_n和c_n的shape保持不变,参数解释见前文。

batch_first

如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:

input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)

变为:

input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)

即batch_size提前。

案例

简单搭建一个LSTM如下所示:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1 # 单向LSTM
        self.batch_size = batch_size
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input_seq):
        batch_size, seq_len = input_seq[0], input_seq[1]
        h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
        pred = self.linear(output)  # (5, 30, 1)
        pred = pred[:, -1, :]  # (5, 1)
        return pred

其中定义模型的代码为:

self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)

我们加上具体的数字:

self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

再看前向传播:

def forward(self, input_seq):
    batch_size, seq_len = input_seq[0], input_seq[1]
    h_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device)
    c_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device)
    # input(batch_size, seq_len, input_size)
    # output(batch_size, seq_len, num_directions * hidden_size)
    output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)
    pred = self.linear(output) # (5, 30, 1)
    pred = pred[:, -1, :]  # (5, 1)
    return pred

假设用前30个预测下一个,则seq_len=30,batch_size=5,由于设置了batch_first=True,因此,输入到LSTM中的input的shape应该为:

input(batch_size, seq_len, input_size) = input(5, 30, 1)

经过DataLoader处理后的input_seq为:

input_seq(batch_size, seq_len, input_size) = input_seq(5, 30, 1)

然后将input_seq送入LSTM:

output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)

根据前文,output的shape为:

output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)

全连接层的定义为:

self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

然后将output送入全连接层:

pred = self.linear(output)  # pred(5, 30, 1)

得到的预测值shape为(5, 30, 1),由于输出是输入右移,我们只需要取pred第二维度(time)中的最后一个数据:

pred = pred[:, -1, :]  # (5, 1)

这样,我们就得到了预测值,然后与label求loss,然后再反向更新参数即可。

到此,关于“PyTorch中LSTM的输入和输出实例分析”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注亿速云网站,小编会继续努力为大家带来更多实用的文章!

推荐阅读:
  1. 基于pytorch的lstm参数使用详解
  2. pytorch+lstm实现的pos示例

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

pytorch lstm

上一篇:Java集合类之Map怎么使用

下一篇:php如何判断是否为空数组

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》