PyTorch 张量在指定维度复制内容:repeat() 和 repeat_interleave() 函数
在 PyTorch 的张量中复制某一维度,可以使用 torch.repeat() 或 torch.repeat_interleave() 函数。
- 使用
torch.repeat()函数可以在指定维度上复制张量的元素。该函数接受一个整数参数,表示要复制的次数。例如,要在第 0 维上复制张量,可以使用以下代码:
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.repeat(3, 1)
print(y)
输出为:
tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
- 使用
torch.repeat_interleave()函数可以在指定维度上复制并交叉插入张量的元素。该函数接受一个整数参数,表示要复制的次数。例如,要在第 0 维上复制并交叉插入张量,可以使用以下代码:
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.repeat_interleave(3, 0)
print(y)
输出为:
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
这两个函数都会返回复制后的新张量,并不会改变原始张量的值。
原文地址: http://www.cveoy.top/t/topic/bncV 著作权归作者所有。请勿转载和采集!