PyTorch LeNet_1D_V 多分类模型:使用交叉熵损失函数

要求模型的多分类损失函数为交叉熵函数,可以使用 PyTorch 中的nn.CrossEntropyLoss()函数作为损失函数。

在模型的 forward 函数中,输出的output是经过 softmax 函数处理后的结果,可以直接将其作为预测概率输入到交叉熵函数中计算损失。

修改代码如下:

import torch.nn as nn

class LeNet_1D_V(nn.Module):
    def __init__(self):
        super(LeNet_1D_V, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 5, kernel_size=3, stride=1, padding=1),
            nn.Mish()
        )
        # 省略其他层的定义
        self.softmax = nn.Softmax(dim=1)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, target=None):
        x = self.conv1(x)
        # 省略其他层的计算
        x = self.global_avgpool(x)
        x = x.view(x.size(0), -1)
        feature = x
        x = self.fc(x)
        output = self.softmax(x)

        if target is not None:
            loss = self.loss_fn(output, target)
            return feature, output, loss
        else:
            return feature, output

model = LeNet_1D_V()
print(model)

在模型的 forward 函数中加入了一个可选的参数target,用于传入真实标签。如果传入了target,则计算交叉熵损失并返回。如果没有传入target,则只返回特征和预测概率。

使用交叉熵损失函数时,训练过程中需要将真实标签传入模型进行损失计算,并根据损失进行反向传播和参数更新。

PyTorch LeNet_1D_V 多分类模型:使用交叉熵损失函数

原文地址: https://www.cveoy.top/t/topic/b76B 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录