PyTorch 张量维度扩展:使用 torch.unsqueeze() 函数
在 PyTorch 中,可以使用 torch.unsqueeze() 函数来实现张量的维度扩展。这个函数可以在指定位置上增加一个维度。
下面是一个示例,展示了如何使用 torch.unsqueeze() 来扩展张量的维度:
import torch
# 创建一个2维张量
x = torch.tensor([[1, 2], [3, 4]])
# 扩展维度,将原始张量的维度沿着指定位置增加一个维度
expanded_x = torch.unsqueeze(x, dim=0)
print('原始张量的形状:', x.shape)
print('扩展后张量的形状:', expanded_x.shape)
输出结果为:
原始张量的形状: torch.Size([2, 2])
扩展后张量的形状: torch.Size([1, 2, 2])
在上述示例中,torch.unsqueeze(x, dim=0) 将原始张量 x 的维度沿着 dim=0 的位置增加了一个维度,结果张量 expanded_x 的形状变为 (1, 2, 2)。
需要注意的是,dim 参数指定了要扩展的位置,它的取值范围为 [-(D+1), D+1],其中 D 是原始张量的维度。dim=0 表示在第一个维度之前增加一个维度,dim=1 表示在第二个维度之前增加一个维度,以此类推。
原文地址: https://www.cveoy.top/t/topic/bnil 著作权归作者所有。请勿转载和采集!