在 PyTorch 的张量中复制某一维度,可以使用 torch.repeat()torch.repeat_interleave() 函数。

  • 使用 torch.repeat() 函数可以在指定维度上复制张量的元素。该函数接受一个整数参数,表示要复制的次数。例如,要在第 0 维上复制张量,可以使用以下代码:
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.repeat(3, 1)
print(y)

输出为:

tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])
  • 使用 torch.repeat_interleave() 函数可以在指定维度上复制并交叉插入张量的元素。该函数接受一个整数参数,表示要复制的次数。例如,要在第 0 维上复制并交叉插入张量,可以使用以下代码:
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.repeat_interleave(3, 0)
print(y)

输出为:

tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]])

这两个函数都会返回复制后的新张量,并不会改变原始张量的值。

PyTorch 张量在指定维度复制内容:repeat() 和 repeat_interleave() 函数

原文地址: http://www.cveoy.top/t/topic/bncV 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录