PyTorch 中的 concatstack 都可以将多个张量按照指定维度进行拼接,但它们的拼接方式有所不同。

torch.cat 的作用是将多个张量在指定维度上进行拼接,拼接后的张量在指定维度上的大小等于所有拼接的张量在该维度上的大小之和。例如:

import torch

# 创建三个张量,shape 分别为 (2, 2), (2, 2), (2, 2)
x = torch.randn(2, 2)
y = torch.randn(2, 2)
z = torch.randn(2, 2)

# 在维度 0 上拼接三个张量
result = torch.cat([x, y, z], dim=0)
print(result)
print(result.shape)

输出结果为:

tensor([[ 0.3588,  1.0895],
        [ 0.6425,  0.8629],
        [-2.1657,  0.1006],
        [-1.0413,  0.4479],
        [-0.2289, -1.2708],
        [-0.6963, -0.4206]])
torch.Size([6, 2])

torch.stack 的作用是将多个张量在新的维度上进行拼接,拼接后的张量的新维度大小等于拼接的张量的个数。例如:

import torch

# 创建三个张量,shape 分别为 (2, 2), (2, 2), (2, 2)
x = torch.randn(2, 2)
y = torch.randn(2, 2)
z = torch.randn(2, 2)

# 在新的维度上拼接三个张量
result = torch.stack([x, y, z], dim=0)
print(result)
print(result.shape)

输出结果为:

tensor([[[ 0.2097,  0.8226],
         [ 0.1723, -0.8948]],

        [[ 0.6933, -0.2885],
         [-0.5579, -0.4492]],

        [[-0.5385,  0.6642],
         [-0.0817,  0.8586]]])
torch.Size([3, 2, 2])

可以看到,torch.stack 在新的维度上创建了一个新的维度,而 torch.cat 则是将原有的维度进行拼接。

PyTorch 中 concat 和 stack 的区别及示例

原文地址: https://www.cveoy.top/t/topic/m6s6 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录