PyTorch实现自定义损失函数的多标签图像分类
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
# File paths.
# Raw strings are required: in a normal literal "\U" is a syntax error and
# "\1", "\b" and "\t" would be turned into control characters.
file_path = r'C:\Users\18105\Desktop\MVSA-Single\MVSA_Single\biaoqian.txt'
pt_file_path = r'C:\Users\18105\PycharmProjects\tuwenqingganfenxi\expanded.pt'
# Read the label data: one label vector per line, values separated by whitespace.
tensor_list = []
with open(file_path, 'r') as file:
    # Iterate the file object directly instead of loading every line up front.
    for line in file:
        # split() with no argument already ignores leading/trailing whitespace.
        numbers = line.split()
        # Every line is kept (a blank line yields an empty tensor) so that the
        # position in tensor_list stays equal to the line's position in the file.
        tensor_list.append(torch.tensor([float(num) for num in numbers]))
# Define the network
class MyNetwork(nn.Module):
    """Two-layer fully connected network.

    The input is flattened, passed through a linear layer followed by ReLU,
    and then through a second linear layer that produces the output scores.

    Args:
        in_features: Number of values per sample after flattening.
        hidden_features: Width of the hidden layer.
        num_classes: Number of output scores per sample.

    The defaults reproduce the original fixed sizes (512 * 256 -> 256 -> 4).
    """

    def __init__(self, in_features: int = 512 * 256,
                 hidden_features: int = 256, num_classes: int = 4):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw output of the last linear layer for ``x``.

        No softmax or other normalisation is applied to the result.
        """
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Create the network instance
network = MyNetwork()
# Load the training data
data = torch.load(pt_file_path)
# torch.tensor() applied to an existing tensor emits a copy-construct warning;
# detach().clone() is the documented equivalent. Items that are not tensors
# are still converted with torch.tensor().
input_tensors = [
    sample.detach().clone() if isinstance(sample, torch.Tensor) else torch.tensor(sample)
    for sample in data
]
# Custom loss function
def custom_loss(output, label):
    """Margin-style loss built from cosine similarities.

    Computes ``1 - (s_target - s_other)`` where

    * ``s_target`` is the cosine similarity (along ``dim=1``) between
      ``output`` and ``label`` with a leading dimension added, and
    * ``s_other`` is the largest cosine similarity between ``output`` and any
      single entry of ``label`` other than the one at ``argmax(label)``;
      each entry is broadcast against ``output`` by ``cosine_similarity``.

    Args:
        output: Tensor of scores; similarities are taken along dimension 1.
        label: Label tensor that is iterated entry by entry.

    Returns:
        A tensor with the same shape as ``s_target``.

    Raises:
        RuntimeError: Raised by ``torch.cat`` when ``label`` has no entry
            besides the arg-max one, because there is nothing to concatenate.
    """
    target_similarity = F.cosine_similarity(output, label.unsqueeze(0), dim=1)
    # The arg-max position is loop-invariant, so look it up once.
    target_index = int(torch.argmax(label))
    other_similarities = [
        F.cosine_similarity(output, entry.unsqueeze(0), dim=1)
        for i, entry in enumerate(label)
        if i != target_index
    ]
    # Only the strongest competing similarity enters the margin.
    strongest_other = torch.max(torch.cat(other_similarities))
    diff = target_similarity - strongest_other
    return 1 - diff
# Define the optimizer
optimizer = optim.AdamW(network.parameters(), lr=0.01)
# Train the network: one optimisation step per sample, statistics per epoch.
num_epochs = 100
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    for i, input_tensor in enumerate(input_tensors):
        # The label is looked up by position; a missing label raises IndexError
        # instead of silently shortening the epoch.
        label = tensor_list[i]
        # Reset the gradients
        optimizer.zero_grad()
        # Forward pass
        output = network(input_tensor)
        # Compute the loss
        loss = custom_loss(output, label)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        # Accuracy bookkeeping; detach() replaces the legacy `.data` access.
        predicted = output.detach().argmax(dim=1)
        total += 1
        if predicted.item() == torch.argmax(label).item():
            correct += 1
        # Accumulate the loss (summed over the epoch, not averaged)
        running_loss += loss.item()
    # Print the training statistics; avoid dividing by zero when there are no samples.
    accuracy = 100 * correct / total if total else 0.0
    print('Epoch: %d, Loss: %.3f, Accuracy: %.2f%%' % (epoch+1, running_loss, accuracy))
原文地址: https://www.cveoy.top/t/topic/j7t 著作权归作者所有。请勿转载和采集!