pytorch中打乱特征图的通道顺序如何做到可以像conv一样可以参与学习和反向传播给出你的代码可以实现以下想法我的想法是输入一个bchw的特征图打乱顺序后再平均切分通道数得到split_num份子特征图加入到一个列表list中。
可以使用torch.randperm函数来实现随机打乱通道顺序,然后使用torch.index_select函数按照打乱后的顺序重新排列通道。
代码如下:
import torch
def shuffle_channels(x):
batch_size, num_channels, height, width = x.size()
shuffled_idx = torch.randperm(num_channels)
x = x[:, shuffled_idx, :, :].contiguous()
split_num = 4 # 分成4份子特征图
x_list = torch.split(x, int(num_channels/split_num), dim=1)
return x_list
# 测试
x = torch.randn(2, 12, 64, 64)
x_list = shuffle_channels(x)
print(len(x_list)) # 输出4
原文地址: https://www.cveoy.top/t/topic/bx3k 著作权归作者所有。请勿转载和采集!