PyTorch 张量的广播(broadcasting)机制允许在不同形状的张量之间进行算术运算。广播会按照一定的规则自动扩展较小张量的维度,使其与较大张量的维度匹配,然后进行逐元素(element-wise)运算。
以下是 PyTorch 广播的基本规则:
下面是一个简单的例子来说明 PyTorch 中的广播机制:
import torch
# 创建两个张量
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([1., 2.])
# 广播 b 到与 a 相同的形状
b_expanded = b.expand(-1, -1)
# 进行逐元素运算
result = a + b_expanded
print(result)
输出结果:
tensor([[2., 4.],
[4., 6.]])
在这个例子中,我们首先创建了一个形状为 (2, 2) 的张量 a
和一个形状为 (2,) 的张量 b
。然后,我们使用 expand
方法将 b
扩展为与 a
相同的形状。最后,我们对两个张量进行逐元素加法运算,得到结果。