parameter与buffer怎么在Pytorch模型中使用

发布时间:2021-06-01 16:19:34 作者:Leah
来源:亿速云 阅读:165

本篇文章给大家分享的是有关parameter与buffer怎么在Pytorch模型中使用,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。

第一种参数有两种方式

我们可以直接将模型的成员变量(http://self.xxx) 通过nn.Parameter() 创建,会自动注册到parameters中,可以通过model.parameters() 返回,并且这样创建的参数会自动保存到OrderDict中去;

通过nn.Parameter() 创建普通Parameter对象,不作为模型的成员变量,然后将Parameter对象通过register_parameter()进行注册,可以通model.parameters() 返回,注册后的参数也会自动保存到OrderDict中去;

第二种参数我们需要创建tensor

然后将tensor通过register_buffer()进行注册,可以通model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。

Pytorch中Module,Parameter和Buffer区别

下文都将torch.nn简写成nn

Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类。

Buffer: buffer和parameter相对,就是指那些不需要参与反向传播的参数

示例如下:

class MyModel(nn.Module):
 def __init__(self):
  super(MyModel, self).__init__()
  self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
  self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
  self.my_param = nn.Parameter(torch.randn(1))
 def forward(self, x):
  return x 

model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
Parameter: 是nn.parameter.Paramter,也就是组成Module的参数。例如一个nn.Linear通常由weight和bias参数组成。它的特点是默认requires_grad=True,也就是说训练过程中需要反向传播的,就需要使用这个
import torch.nn as nn
fc = nn.Linear(2,2)

# 读取参数的方式一
fc._parameters
>>> OrderedDict([('weight', Parameter containing:
              tensor([[0.4142, 0.0424],
                      [0.3940, 0.0796]], requires_grad=True)),
             ('bias', Parameter containing:
              tensor([-0.2885,  0.5825], requires_grad=True))])
     
# 读取参数的方式二(推荐这种)
for n, p in fc.named_parameters():
 print(n,p)
>>>weight Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
bias Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

# 读取参数的方式三
for p in fc.parameters():
 print(p)
>>>Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

通过上面的例子可以看到,nn.parameter.Paramter的requires_grad属性值默认为True。另外上面例子给出了三种读取parameter的方法,推荐使用后面两种,因为是以迭代生成器的方式来读取,第一种方式是一股脑的把参数全丢给你,要是模型很大,估计你的电脑会吃不消。

另外需要介绍的是_parameters是nn.Module在__init__()函数中就定义了的一个OrderDict类,这个可以通过看下面给出的部分源码看到,可以看到还初始化了很多其他东西,其实原理都大同小异,你理解了这个之后,其他的也是同样的道理。

class Module(object):
 ...
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

每当我们给一个成员变量定义一个nn.parameter.Paramter的时候,都会自动注册到_parameters,具体的步骤如下:

import torch.nn as nn
class MyModel(nn.Module):
 def __init__(self):
  super(MyModel, self).__init__()
  # 下面两种定义方式均可
  self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
  print(self._parameters)
  self.p2 = nn.Paramter(torch.tensor(2.0))
  print(self._parameters)

首先运行super(MyModel, self).__init__(),这样MyModel就初始化了_paramters等一系列的OrderDict,此时所有变量还都是空的。

self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 这行代码会触发nn.Module预定义好的__setattr__函数,该函数部分源码如下:

def __setattr__(self, name, value):
 ...
 params = self.__dict__.get('_parameters')
 if isinstance(value, Parameter):
  if params is None:
   raise AttributeError(
    "cannot assign parameters before Module.__init__() call")
  remove_from(self.__dict__, self._buffers, self._modules)
  self.register_parameter(name, value)
 ...

__setattr__函数作用简单理解就是判断你定义的参数是否正确,如果正确就继续调用register_parameter函数进行注册,这个函数简单概括就是做了下面这件事

def register_parameter(self,name,param):
 ...
 self._parameters[name]=param

下面我们实例化这个模型看结果怎样

model = MyModel()
>>>OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True))])
OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True)), ('p2', Parameter containing:
tensor(2., requires_grad=True))])

以上就是parameter与buffer怎么在Pytorch模型中使用,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注亿速云行业资讯频道。

推荐阅读:
  1. 如何在pytorch中存储模型
  2. PyTorch如何使用预训练模型

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

pytorch parameter buffer

上一篇:php中swfupload乱码如何解决

下一篇:使用BootStrap怎么实现栅格布局

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》