import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn

# File paths. Raw strings are required: in a normal string literal the "\U"
# in "C:\Users" starts a \UXXXXXXXX unicode escape (a SyntaxError), and "\t"
# in "\tuwenqingganfenxi" would silently become a tab character.
file_path = r'C:\Users\18105\Desktop\MVSA-Single\MVSA_Single\biaoqian.txt'
pt_file_path = r'C:\Users\18105\PycharmProjects\tuwenqingganfenxi\expanded.pt'

# Read the label data: each non-blank line is a whitespace-separated row of
# numbers and becomes one 1-D float tensor. Rows are later paired with the
# input tensors by position, so row order matters.
tensor_list = []
with open(file_path, 'r', encoding='utf-8') as file:
    for line in file:
        numbers = line.split()
        if not numbers:
            # Skip blank lines (e.g. an extra trailing newline) rather than
            # appending an empty tensor that argmax() would later reject.
            # NOTE(review): assumes blank lines are never sample placeholders.
            continue
        tensor_list.append(torch.tensor([float(num) for num in numbers]))

# Network definition: flatten -> Linear -> ReLU -> Linear.
class MyNetwork(nn.Module):
    """Two-layer fully connected classifier.

    Args:
        in_features: Number of features per sample after flattening every
            dimension except dim 0. Defaults to 512 * 256, the size this
            network was originally hard-coded for.
        hidden_features: Width of the hidden layer (default 256).
        num_classes: Number of output scores per sample (default 4).
    """

    def __init__(self, in_features=512 * 256, hidden_features=256, num_classes=4):
        super(MyNetwork, self).__init__()
        # nn.Flatten() keeps dim 0 (treated as the batch) and flattens the rest.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) scores of shape (x.shape[0], num_classes)."""
        x = self.flatten(x)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

network = MyNetwork()

# Load the training inputs. torch.load unpickles the file, so only load
# trusted files. NOTE(review): if the file holds only tensors, consider
# torch.load(..., weights_only=True) on torch versions that support it.
data = torch.load(pt_file_path)
# torch.tensor(existing_tensor) warns and recommends clone().detach(); use
# that for tensors and keep torch.tensor() for non-tensor entries. Either way
# each input is a fresh copy detached from any autograd history.
input_tensors = [
    item.detach().clone() if isinstance(item, torch.Tensor) else torch.tensor(item)
    for item in data
]

# Custom loss function.
def custom_loss(output, label):
    """Cosine-margin loss between the network output and a label vector.

    The target class is ``argmax(label)``. The loss is::

        1 - (cos(output, label) - max_{j != target} cos(output, e_j))

    where ``e_j`` is the one-hot direction of class ``j``. Minimizing it pulls
    the output toward the label and away from the strongest competing class.
    With a single class there is no competitor and the loss is
    ``1 - cos(output, label)``.

    Args:
        output: Network scores of shape (N, C). The training loop presumably
            passes N == 1 (one sample at a time) — TODO confirm for batches.
        label: 1-D tensor of length C; its argmax is the target class.

    Returns:
        Tensor of shape (N,): one loss value per output row.
    """
    num_classes = label.shape[0]
    target_idx = int(torch.argmax(label))

    target_similarity = F.cosine_similarity(output, label.unsqueeze(0), dim=1)
    if num_classes < 2:
        # No competing class exists, so there is no margin term.
        return 1 - target_similarity

    # Compare against each competing class's one-hot direction. (Comparing
    # against the scalar label[j] would broadcast it to a constant vector,
    # which is all zeros for one-hot labels and makes the margin term vanish.)
    basis = torch.eye(num_classes, dtype=output.dtype, device=output.device)
    other_similarities = torch.cat([
        F.cosine_similarity(output, basis[j].unsqueeze(0), dim=1)
        for j in range(num_classes)
        if j != target_idx
    ])

    diff = target_similarity - torch.max(other_similarities)
    loss = 1 - diff

    return loss

# Define the optimizer: AdamW over all network parameters.
optimizer = optim.AdamW(network.parameters(), lr=0.01)

# Inputs and labels are paired by position. Fail fast with a clear message
# instead of an IndexError partway through the first epoch.
if len(tensor_list) < len(input_tensors):
    raise ValueError(
        f"{len(input_tensors)} input tensors but only {len(tensor_list)} labels"
    )

# Train the network.
num_epochs = 100
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    for input_tensor, label in zip(input_tensors, tensor_list):
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass.
        output = network(input_tensor)

        # Compute the loss.
        loss = custom_loss(output, label)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Accuracy bookkeeping: predicted class vs. the label's argmax.
        # .item() below assumes one output row per step (batch of 1).
        predicted = output.detach().argmax(dim=1)
        total += 1
        if predicted.item() == int(torch.argmax(label)):
            correct += 1

        # Accumulate the loss.
        running_loss += loss.item()

    # Report the mean per-sample loss (a plain sum would scale with dataset
    # size); guard against an empty dataset.
    avg_loss = running_loss / total if total else 0.0
    accuracy = 100 * correct / total if total else 0.0
    print('Epoch: %d, Loss: %.3f, Accuracy: %.2f%%' % (epoch+1, avg_loss, accuracy))
# PyTorch实现自定义损失函数的多标签图像分类
# (Multi-label image classification in PyTorch with a custom loss function)
#
# 原文地址 (source): https://www.cveoy.top/t/topic/j7t 著作权归作者所有。请勿转载和采集!
#
# 免费AI点我,无需注册和登录