FloatTensor与Variable怎么在Pytorch中使用

发布时间:2021-03-31 17:25:12 作者:Leah
来源:亿速云 阅读:234

FloatTensor与Variable怎么在Pytorch中使用?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。

pytorch中基本的变量类型当属FloatTensor(以下都用floattensor),而Variable(以下都用variable)是floattensor的封装,除了包含floattensor还包含有梯度信息

pytorch中的dochi给出一些对于floattensor的基本的操作,比如四则运算以及平方等(链接),这些操作对于floattensor是十分的不友好,有时候需要写一个正则化的项需要写很长的一串,比如两个floattensor之间的相加需要用torch.add()来实现

for step in range(config.total_step):

    
    # Extract multiple(5) conv feature vectors
    target_features = vgg(target)  # 每一次输入到网络中的是同样一张图片,反传优化的目标是输入的target
    content_features = vgg(Variable(content))
    style_features = vgg(Variable(style))

    style_loss = 0
    content_loss = 0
    for f1, f2, f3 in zip(target_features, content_features, style_features):
      # Compute content loss (target and content image)
      content_loss += torch.mean((f1 - f2)**2) # square 可以进行直接加-操作?可以,并且mean对所有的元素进行均值化造作

      # Reshape conv features
      _, c, h, w = f1.size() # channel height width
      f1 = f1.view(c, h * w) # reshape a vector
      f3 = f3.view(c, h * w) # reshape a vector

      # Compute gram matrix 
      f1 = torch.mm(f1, f1.t())
      f3 = torch.mm(f3, f3.t())

      # Compute style loss (target and style image)
      style_loss += torch.mean((f1 - f3)**2) / (c * h * w)  # 总共元素的数目?

其中f1与f2,f3的变量类型是Variable,作者对其直接用四则运算符进行加减,并且用python内置的**进行平方操作,然后

# -*-coding: utf-8 -*-
import torch
from torch.autograd import Variable

# dtype = torch.FloatTensor
dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype) # 两个权重矩阵
w2 = torch.randn(D_in, H).type(dtype)
# operate with +-*/ and **
w3 = w1-2*w2
w4 = w3**2
w5 = w4/w1


# operate the Variable with +-*/ and **
w6 = Variable(torch.randn(N, D_in).type(dtype))
w7 = Variable(torch.randn(N, D_in).type(dtype))
w8 = w6 + w7
w9 = w6*w7
w10 = w9**2
print(1)

基本上调试的结果与预期相符

FloatTensor与Variable怎么在Pytorch中使用

看完上述内容,你们掌握FloatTensor与Variable怎么在Pytorch中使用的方法了吗?如果还想学到更多技能或想了解更多相关内容,欢迎关注亿速云行业资讯频道,感谢各位的阅读!

推荐阅读:
  1. Pytorch中自动求梯度机制和Variable类的示例分析
  2. PyTorch中Variable变量的作用是什么

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

pytorch variable

上一篇:Summary如何在Tensorflow中使用

下一篇:pushd和popd命令怎么在Linux中使用

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》