pytorch

pytorch张量怎样广播

小樊
81
2024-12-25 20:52:26
栏目: 深度学习

PyTorch 张量的广播(broadcasting)机制允许在不同形状的张量之间进行算术运算。广播会按照一定的规则自动扩展较小张量的维度,使其与较大张量的维度匹配,然后进行逐元素(element-wise)运算。

以下是 PyTorch 广播的基本规则:

  1. 如果两个张量的维度不同,将较小的维度扩展为与较大的维度相同。扩展的方式是在较小的维度前面补 1。
  2. 如果两个张量在某一个维度上的大小相同,或者其中一个张量在该维度上的大小为 1,则这两个张量在该维度上可以进行广播。
  3. 如果两个张量在所有维度上的大小都相同,那么它们可以直接进行逐元素运算。否则,会抛出错误。

下面是一个简单的例子来说明 PyTorch 中的广播机制:

import torch

# 创建两个张量
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([1., 2.])

# 广播 b 到与 a 相同的形状
b_expanded = b.expand(-1, -1)

# 进行逐元素运算
result = a + b_expanded
print(result)

输出结果:

tensor([[2., 4.],
        [4., 6.]])

在这个例子中,我们首先创建了一个形状为 (2, 2) 的张量 a 和一个形状为 (2,) 的张量 b。然后,我们使用 expand 方法将 b 扩展为与 a 相同的形状。最后,我们对两个张量进行逐元素加法运算,得到结果。

0
看了该问题的人还看了