PyTorch 中的 torch.concat 和 torch.stack:详细比较与示例
PyTorch 中的 torch.concat 和 torch.stack 都是用来将多个张量合并的函数,但它们在合并方式上有所不同:
-
torch.concat:沿着指定维度拼接张量torch.concat函数将多个张量沿着指定的维度进行拼接,返回一个新的张量。拼接后的张量形状在指定维度上的长度等于所有拼接的张量在该维度上的长度之和。import torch a = torch.randn(2, 3) b = torch.randn(2, 3) c = torch.concat([a, b], dim=0) print(a) print(b) print(c)输出:
tensor([[ 1.3023, -0.6175, 0.2248], [-0.0512, -0.2462, -1.4557]]) tensor([[-0.7087, -0.3726, -0.7841], [ 1.2778, 0.5525, 0.2949]]) tensor([[ 1.3023, -0.6175, 0.2248], [-0.0512, -0.2462, -1.4557], [-0.7087, -0.3726, -0.7841], [ 1.2778, 0.5525, 0.2949]])在这个例子中,我们定义了两个 2x3 的张量
a和b,然后使用torch.concat沿着第 0 维进行拼接,得到一个 4x3 的新张量c。 -
torch.stack:沿着新维度堆叠张量torch.stack函数将多个张量沿着新的维度进行堆叠,返回一个新的张量。新的维度的长度等于堆叠的张量的数量。import torch a = torch.randn(2, 3) b = torch.randn(2, 3) c = torch.stack([a, b], dim=0) print(a) print(b) print(c)输出:
tensor([[-0.1772, 0.1090, 0.0829], [ 0.4024, -0.2663, -0.9918]]) tensor([[-0.1169, -0.5678, 0.5602], [-0.3511, 0.0914, 0.3679]]) tensor([[[-0.1772, 0.1090, 0.0829], [ 0.4024, -0.2663, -0.9918]], [[-0.1169, -0.5678, 0.5602], [-0.3511, 0.0914, 0.3679]]])在这个例子中,我们定义了两个 2x3 的张量
a和b,然后使用torch.stack沿着新的第 0 维进行堆叠,得到一个 2x2x3 的新张量c。注意,新的第 0 维的长度等于堆叠的张量的数量,即 2。
总结:
torch.concat在现有维度上拼接张量,不增加新的维度。torch.stack沿着新的维度堆叠张量,增加新的维度。
选择 torch.concat 还是 torch.stack 取决于你想要合并张量的具体方式。如果你想要将多个张量在现有维度上拼接,使用 torch.concat。如果你想要将多个张量沿着新的维度堆叠,使用 torch.stack。
原文地址: https://www.cveoy.top/t/topic/m6qS 著作权归作者所有。请勿转载和采集!