在my_net类中,第三个全连接层使用了fc2两次,应该将其改为fc3。同时,将第三个全连接层fc3的输出维度改为10,与输出类别数目一致。修改后的代码如下:

class my_net(nn.Cell): # 定义算子 def init(self): super(my_net, self).init() self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, pad_mode='valid') self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, pad_mode='valid')

    self.relu = nn.ReLU()  # 激活函数
    self.pool = nn.MaxPool2d(kernel_size=2,stride=2)
    self.flatten = nn.Flatten()
    self.dropout = nn.Dropout(0.2)
    self.fc1 = nn.Dense(4*4*16, 120)  # 全连接层
    self.fc2 = nn.Dense(120, 84)  # 全连接层
    self.fc3 = nn.Dense(84, 10)  # 全连接层

# 定义前向传播
def construct(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.pool(x)
    
    x = self.conv2(x)
    x = self.relu(x)
    x = self.pool(x)
    
    x = self.flatten(x)
    x = self.fc1(x)
    x = self.relu(x)
    x = self.dropout(x)
    
    x = self.fc2(x)
    x = self.relu(x)
    x = self.dropout(x)
    
    x = self.fc3(x)
    
    return x

原文地址: https://www.cveoy.top/t/topic/bd2g 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录