PyTorch 中的 torch.concattorch.stack 都是用来将多个张量合并的函数,但它们在合并方式上有所不同:

  1. torch.concat:沿着指定维度拼接张量

    torch.concat 函数将多个张量沿着指定的维度进行拼接,返回一个新的张量。拼接后的张量形状在指定维度上的长度等于所有拼接的张量在该维度上的长度之和。

    import torch
    
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    c = torch.concat([a, b], dim=0)
    
    print(a)
    print(b)
    print(c)
    

    输出:

    tensor([[ 1.3023, -0.6175,  0.2248],
            [-0.0512, -0.2462, -1.4557]])
    tensor([[-0.7087, -0.3726, -0.7841],
            [ 1.2778,  0.5525,  0.2949]])
    tensor([[ 1.3023, -0.6175,  0.2248],
            [-0.0512, -0.2462, -1.4557],
            [-0.7087, -0.3726, -0.7841],
            [ 1.2778,  0.5525,  0.2949]])
    

    在这个例子中,我们定义了两个 2x3 的张量 ab,然后使用 torch.concat 沿着第 0 维进行拼接,得到一个 4x3 的新张量 c

  2. torch.stack:沿着新维度堆叠张量

    torch.stack 函数将多个张量沿着新的维度进行堆叠,返回一个新的张量。新的维度的长度等于堆叠的张量的数量。

    import torch
    
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    c = torch.stack([a, b], dim=0)
    
    print(a)
    print(b)
    print(c)
    

    输出:

    tensor([[-0.1772,  0.1090,  0.0829],
            [ 0.4024, -0.2663, -0.9918]])
    tensor([[-0.1169, -0.5678,  0.5602],
            [-0.3511,  0.0914,  0.3679]])
    tensor([[[-0.1772,  0.1090,  0.0829],
             [ 0.4024, -0.2663, -0.9918]],
    
            [[-0.1169, -0.5678,  0.5602],
             [-0.3511,  0.0914,  0.3679]]])
    

    在这个例子中,我们定义了两个 2x3 的张量 ab,然后使用 torch.stack 沿着新的第 0 维进行堆叠,得到一个 2x2x3 的新张量 c。注意,新的第 0 维的长度等于堆叠的张量的数量,即 2。

总结:

  • torch.concat 在现有维度上拼接张量,不增加新的维度。
  • torch.stack 沿着新的维度堆叠张量,增加新的维度。

选择 torch.concat 还是 torch.stack 取决于你想要合并张量的具体方式。如果你想要将多个张量在现有维度上拼接,使用 torch.concat。如果你想要将多个张量沿着新的维度堆叠,使用 torch.stack

PyTorch 中的 torch.concat 和 torch.stack:详细比较与示例

原文地址: https://www.cveoy.top/t/topic/m6qS 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录