This article covers how to save a trained model in PyTorch. It is a situation many people run into in practical projects, so let's walk through how to handle it.
After training a model on your data and getting a satisfactory result, you cannot retrain it from scratch every time it is needed in a real application. Instead, you save the trained model once and simply load it whenever you need to use it.
A model is essentially a collection of parameters stored in a particular structure, so there are two ways to save it.
One is to save the entire model object and later load the whole thing back; this is simpler but uses more memory.
The other is to save only the model's parameters; when the model is needed again, you create a new model with the same structure and load the saved parameters into it.
(1) Save only the model's parameter dictionary (recommended)
# Save
torch.save(the_model.state_dict(), PATH)

# Load
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
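One practical detail: if the parameters were saved on a GPU machine but have to be loaded on a CPU-only machine, torch.load accepts a map_location argument. Below is a minimal sketch reusing the placeholder names TheModelClass, args/kwargs and PATH from the snippet above:

# Load GPU-saved parameters onto the CPU; TheModelClass, args/kwargs and PATH
# are the placeholders from the snippet above
the_model = TheModelClass(*args, **kwargs)
state_dict = torch.load(PATH, map_location=torch.device('cpu'))
the_model.load_state_dict(state_dict)
the_model.eval()  # switch to evaluation mode before running inference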
(2) Save the entire model
# Save
torch.save(the_model, PATH)

# Load
the_model = torch.load(PATH)
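One thing to keep in mind with this second approach: torch.save pickles the whole model, and pickle stores only a reference to the model's class, not its source code. The script that calls torch.load therefore needs the class definition to be importable under the same module path as when it was saved. A minimal sketch, assuming the class was defined in a hypothetical module model.py:

import torch
from model import TheModelClass  # hypothetical module; the class must be importable when unpickling

the_model = torch.load(PATH)  # PATH as in the snippet above
the_model.eval()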
PyTorch stores a model's parameters in a dictionary, so all we have to do is save this dictionary and load it back later.
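To see what this dictionary actually contains, you can iterate over state_dict(): each entry maps a parameter name to a tensor. A minimal sketch using a stand-in nn.Linear module:

import torch.nn as nn

# Inspect the parameter dictionary of a tiny example module
model = nn.Linear(3, 2)
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# prints: weight (2, 3) and bias (2,)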
For example, we can design a simple LSTM network and train it; once training finishes, we save the model's parameter dictionary to a file named rnn.pt in the same folder:
import torch
import torch.nn as nn

# Use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM; out has shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out)
        return out


rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)

optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all model parameters
loss_func = nn.MSELoss()                                  # regression loss

for epoch in range(1000):
    output = rnn(train_tensor)                      # model output
    loss = loss_func(output, train_labels_tensor)   # MSE loss
    optimizer.zero_grad()                           # clear gradients for this training step
    loss.backward()                                 # backpropagation, compute gradients
    optimizer.step()                                # apply gradients
    output_sum = output

# Save the model's parameter dictionary
torch.save(rnn.state_dict(), 'rnn.pt')
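If the goal is not just inference but also the ability to resume training later, a common extension is to store the optimizer state and the current epoch alongside the model parameters in a single dictionary. This is a sketch rather than part of the original example; rnn, optimizer, epoch and loss are the objects from the training code above, and checkpoint.pt is a hypothetical file name:

# Save a full training checkpoint (hypothetical file name)
torch.save({
    'epoch': epoch,
    'model_state_dict': rnn.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, 'checkpoint.pt')

# Later: restore model and optimizer to continue training where you left off
checkpoint = torch.load('checkpoint.pt')
rnn.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1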
Once the model is saved, we can use it to process new data:
# Test the saved model
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)
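Two habits are worth adding when running the reloaded model for prediction: new_m.eval() puts layers such as dropout and batch normalization into inference mode, and torch.no_grad() disables gradient tracking so the forward pass uses less memory. A small sketch based on the snippet above:

new_m.eval()                      # switch to inference mode
with torch.no_grad():             # gradients are not needed for prediction
    predict = new_m(test_tensor)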
A few notes: when saving, rnn.state_dict() is the parameter dictionary of the rnn model. To test the saved model, first load that dictionary:

m_state_dict = torch.load('rnn.pt')

Then instantiate a new LSTM object. The arguments must be the same as the ones used to instantiate rnn, so that the new model has exactly the same structure:

new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)

Next, load the saved parameters into this new model:

new_m.load_state_dict(m_state_dict)

Finally, the model can be used to process data:
predict = new_m(test_tensor)
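As a side note, load_state_dict checks that every parameter name and shape in the dictionary matches the new model; by default (strict=True) any mismatch raises a RuntimeError, and even with strict=False a shape mismatch is still an error (strict=False only tolerates missing or unexpected keys). A small sketch; the mismatched wrong_m model exists only to illustrate the failure:

# Succeeds because the architecture matches the saved rnn exactly
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)  # strict=True by default

# A different hidden_size would make the parameter shapes incompatible:
# wrong_m = LSTM(input_size=1, hidden_size=20, num_layers=2).to(device)
# wrong_m.load_state_dict(m_state_dict)  # raises RuntimeError (size mismatch)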
The second approach is to save the entire model object rather than just its parameter dictionary. The network definition and training loop are identical to the ones above; the only difference is the final save call, which writes the whole model to rnn1.pt:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM; out has shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out)
        return out


rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all model parameters
loss_func = nn.MSELoss()                                  # regression loss

for epoch in range(1000):
    output = rnn(train_tensor)                      # model output
    loss = loss_func(output, train_labels_tensor)   # MSE loss
    optimizer.zero_grad()                           # clear gradients for this training step
    loss.backward()                                 # backpropagation, compute gradients
    optimizer.step()                                # apply gradients
    output_sum = output

# Save the entire model
torch.save(rnn, 'rnn1.pt')
Once the model is saved, we can again use it to process data:
# Test the saved model
new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)
That wraps up "How to save a trained model in PyTorch". Thanks for reading.