PyTorch CNN 模型构建与报错解决:输入数据维度不匹配
PyTorch CNN 模型构建与报错解决:输入数据维度不匹配
本文将介绍如何使用 PyTorch 构建一个简单的 CNN 模型,并解决输入数据维度不匹配导致的报错问题。
首先,我们先来看一下 CNN 模型的构建代码:
import torch
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(1, 5, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv2 = nn.Sequential(
nn.Conv1d(5, 10, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv3 = nn.Sequential(
nn.Conv1d(10, 15, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv4 = nn.Sequential(
nn.Conv1d(15, 20, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv5 = nn.Sequential(
nn.Conv1d(20, 25, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv6 = nn.Sequential(
nn.Conv1d(25, 30, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv7 = nn.Sequential(
nn.Conv1d(30, 35, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.conv8 = nn.Sequential(
nn.Conv1d(35, 40, kernel_size=12, stride=1, padding=1),
nn.Mish()
)
self.avgpool = nn.AvgPool1d(kernel_size=2, stride=1)
self.global_avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(40, 6)
self.softmax = nn.Softmax(dim=1)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x):
x = torch.squeeze(x,dim=1)
x = self.conv1(x)
x = self.avgpool(x)
x = self.conv2(x)
x = self.avgpool(x)
x = self.conv3(x)
x = self.avgpool(x)
x = self.conv4(x)
x = self.avgpool(x)
x = self.conv5(x)
x = self.avgpool(x)
x = self.conv6(x)
x = self.avgpool(x)
x = self.conv7(x)
x = self.avgpool(x)
x = self.conv8(x)
x = self.global_avgpool(x)
x = x.view(x.size(0), -1)
feature = x
x = self.fc(x)
output = self.softmax(x)
return feature, output
在运行模型时,我们可能会遇到以下报错:
RuntimeError: Given groups=1, weight of size [5, 1, 12], expected input[100, 10000, 12] to have 1 channels, but got 10000 channels instead
这个报错信息告诉我们,卷积层期望输入数据的通道数为 1,而实际输入的通道数为 10000。那么,[5, 1, 12] 这三个数字代表什么呢?
[5, 1, 12] 表示卷积层的参数权重的形状,其中:
- 5 表示输出通道数;
- 1 表示输入通道数;
- 12 表示卷积核的大小。
因此,报错的原因就是输入数据的维度不匹配,导致无法进行卷积操作。
为了解决这个问题,我们需要对输入数据进行 reshape,使其符合期望的通道数。在 forward 函数中的第一行代码 x = torch.squeeze(x,dim=1) 将输入的维度从 [100, 10000, 12] 转换为 [100, 10000],但这样会丢失通道信息。
为了保留通道信息,我们可以将输入数据的维度调整为 [100, 1, 10000, 12],其中 1 表示输入通道数。这样就符合了期望的通道数,同时不会丢失数据。
可以在 forward 函数中的第一行代码之后添加以下代码:
x = x.unsqueeze(1)
修改后的 forward 函数如下:
def forward(self, x):
x = torch.squeeze(x, dim=1)
x = x.unsqueeze(1)
x = self.conv1(x)
...
这样就可以解决报错问题,并保留输入数据的通道信息。
总结
本文介绍了如何使用 PyTorch 构建一个简单的 CNN 模型,并解决输入数据维度不匹配导致的报错问题。文章详细解释了报错原因,并提供了解决方案,帮助读者理解 CNN 模型的构建过程以及如何处理常见的错误。希望本文能够对读者有所帮助。
原文地址: https://www.cveoy.top/t/topic/ddM9 著作权归作者所有。请勿转载和采集!