PyTorch 一维CNN模型搭建与输入维度调整
PyTorch 一维CNN模型搭建与输入维度调整
在使用PyTorch构建一维卷积神经网络 (1D-CNN) 时,确保输入数据的维度与模型相匹配至关重要。本文将介绍如何搭建一个简单的1D-CNN模型,并重点解析如何调整输入数据的维度以解决常见的 RuntimeError 错误。
1. 模型搭建
以下代码展示了一个简单的1D-CNN模型,包含多个卷积层、平均池化层、全局平均池化层以及全连接层:pythonimport torchimport torch.nn as nn
class CNN(nn.Module): def init(self): super(CNN, self).init() self.conv1 = nn.Sequential( nn.Conv1d(1, 5, kernel_size=12, stride=1, padding=1), nn.Mish() ) # ... 其他卷积层 ... self.conv8 = nn.Sequential( nn.Conv1d(35, 40, kernel_size=12, stride=1, padding=1), nn.Mish() ) self.avgpool = nn.AvgPool1d(kernel_size=2, stride=1) self.global_avgpool = nn.AdaptiveAvgPool1d(1) self.fc = nn.Linear(40, 6) self.softmax = nn.Softmax(dim=1) self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x): # ... 输入维度调整 ... x = self.conv1(x) x = self.avgpool(x) # ... 其他卷积和池化操作 ... x = self.conv8(x) x = self.global_avgpool(x) x = x.view(x.size(0), -1) feature = x x = self.fc(x) output = self.softmax(x) return feature, output
2. 输入维度调整
1D-CNN 要求输入数据的维度为三维:[batch_size, channels, sequence_length]。其中:
batch_size表示批次大小。*channels表示通道数。*sequence_length表示序列长度。
如果输入数据的维度不符合要求,例如是 [batch_size, sequence_length, channels],则需要进行调整。
在上述代码的 forward 函数中,我们假设输入数据的维度为 [100, 10000, 12],其中通道数为12。为了匹配 1D-CNN 的要求,我们需要进行以下调整:python def forward(self, x): x = torch.squeeze(x, dim=1) # 移除维度为1的维度 x = x.unsqueeze(1) # 在第二个维度上增加维度 # ... 后续操作 ..
原文地址: https://www.cveoy.top/t/topic/dejb 著作权归作者所有。请勿转载和采集!