import itertools
import os
import random

import numpy as np
import scipy.io as scio
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from prettytable import PrettyTable
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from My_Basic_Net.utils.OP import WeightOperation
from My_Basic_Net.utils.OP import Bn_bin_conv_pool

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
classes_num = 17  # number of output classes; should equal len(classes) of the dataset config
batch_size = 64
lr = 0.002  # base learning rate for the optimizer

########################################读数据##############################################

读取出来的data是字典格式

dataset_path = '../ECG_Dataset/ECG-17' classes = ['NSR', 'APB', 'AFL', 'AFIB', 'SVTA', 'WPW', 'PVC', 'Bigeminy', 'Trigeminy', 'VT', 'IVR', 'VFL', 'Fusion', 'LBBBB', 'RBBBB', 'SDHB', 'PR'] len_classes = len(classes) X = list() y = list()

# Walk the dataset tree. Each sample's label is taken from the first two characters of its
# directory name, parsed as an int and shifted down by one (presumably 1-based class numbers
# in the folder names -> 0-based labels). Sorting both the walk and the file names makes the
# sample order independent of the file system's listing order.
for root, _dirs, files in sorted(os.walk(dataset_path, topdown=False)):
    for name in sorted(files):
        if not name.lower().endswith('.mat'):
            continue  # loadmat cannot parse stray non-.mat files; skip them
        mat_path = os.path.join(root, name)
        data_train = scio.loadmat(mat_path)  # dict: MATLAB variable name -> array
        data_arr = data_train.get('val')
        if data_arr is None:
            raise KeyError(f"{mat_path} has no 'val' variable")
        # The signal is stored as a single-row matrix: [[...]] -> [...].
        X.append(data_arr.tolist()[0])
        y.append(int(os.path.basename(root)[0:2]) - 1)

# Convert the accumulated lists to tensors: X is (num_records, signal_length) — every
# record must have the same length — and y holds one integer label per record.
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Per-record z-score normalisation: subtract each record's mean and divide by its standard
# deviation (not the variance). The std is clamped so a flat record yields zeros instead of NaN.
X_mean = torch.mean(X, dim=1, keepdim=True)
X_std = torch.std(X, dim=1, keepdim=True).clamp(min=1e-8)
X = (X - X_mean) / X_std

# Add the channel axis expected by the 1-D conv stack: (N, L) -> (N, 1, L). The shape is
# derived from the data rather than hard-coded to 1000 records x 3600 samples.
X = X.unsqueeze(1).to(device)
y = y.reshape(-1).to(device)
# NOTE(review): no random_state is passed and the global seeds are set later in the script,
# so the train/test split differs between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

######################################## Training set ##############################################
class TrainDatasets(Dataset):
    """Map-style dataset serving ``(sample, label)`` pairs from pre-built training tensors.

    Args:
        x_train: tensor of samples, indexed along its first dimension.
        y_train: tensor of labels with the same first-dimension size as ``x_train``.

    Raises:
        ValueError: if ``x_train`` and ``y_train`` hold different numbers of samples.
    """

    def __init__(self, x_train, y_train):
        self.len = x_train.size(0)  # number of samples = size of dimension 0
        if y_train.size(0) != self.len:
            raise ValueError(
                f'x_train has {self.len} samples but y_train has {y_train.size(0)}'
            )
        self.x_train = x_train
        self.y_train = y_train

    def __getitem__(self, index):
        # Return the matching sample/label pair.
        return self.x_train[index], self.y_train[index]

    def __len__(self):
        return self.len

class TestDatasets(Dataset):
    """Map-style dataset serving ``(sample, label)`` pairs from pre-built test tensors.

    Args:
        x_test: tensor of samples, indexed along its first dimension.
        y_test: tensor of labels with the same first-dimension size as ``x_test``.

    Raises:
        ValueError: if ``x_test`` and ``y_test`` hold different numbers of samples.
    """

    def __init__(self, x_test, y_test):
        self.len = x_test.size(0)  # number of samples = size of dimension 0
        if y_test.size(0) != self.len:
            raise ValueError(
                f'x_test has {self.len} samples but y_test has {y_test.size(0)}'
            )
        self.x_test = x_test
        self.y_test = y_test

    def __getitem__(self, index):
        # Return the matching sample/label pair.
        return self.x_test[index], self.y_test[index]

    def __len__(self):
        return self.len

######################################## Wrap datasets in DataLoaders ######################################
train_dataset = TrainDatasets(X_train, y_train)
test_dataset = TestDatasets(X_test, y_test)
# Shuffle only the training set; evaluation order does not affect the metrics.
# The tensors were already moved to `device` before the split, so num_workers stays at its
# default (0) and batches are assembled in the main process.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

######################################## Network definition ##############################################
class ECG_Bin(nn.Module):
    """1-D convolutional ECG classifier built from five ``Bn_bin_conv_pool`` stages.

    Per the argument order listed below, the last stage outputs ``classes_num`` channels.
    ``forward`` averages those over the remaining length axis to get one score per class.

    Args:
        device: device that ``forward`` moves every input batch to.
    """

    def __init__(self, device):
        super().__init__()
        self.name = 'Bin_ECG'
        self.device = device
        self.classifier = nn.Sequential(
            # Bn_bin_conv_pool arguments, in order:
            # input_channels, output_channels, kernel_size, stride, padding, pad_value, pool_size, pool_stride
            Bn_bin_conv_pool(1, 8, 9, 3, 7, 1, 5, 2),
            Bn_bin_conv_pool(8, 16, 9, 1, 7, 1, 5, 2),
            Bn_bin_conv_pool(16, 32, 9, 1, 7, 1, 5, 2),
            Bn_bin_conv_pool(32, 32, 9, 1, 7, 1, 5, 2),
            Bn_bin_conv_pool(32, classes_num, 9, 1, 7, 1, 5, 2),
        )
        self.dropout = nn.Dropout(p=0.5)  # regularisation against over-fitting

    def forward(self, batch_data):
        """Return class scores of shape (batch, classes_num) for *batch_data*.

        Assumes ``batch_data`` is (batch, 1, length) — the torchinfo summary uses
        (batch_size, 1, 3600) — and that the classifier output is 3-D
        (batch, classes_num, length'); TODO confirm against Bn_bin_conv_pool.
        """
        # The input is data, not a trainable quantity. It is detached but NOT marked
        # requires_grad, which would make every backward pass also compute an unused
        # gradient w.r.t. the input.
        batch_data = batch_data.detach().to(self.device)
        batch_data = self.classifier(batch_data)
        # NOTE(review): dropout acts on the final class-score channels, just before pooling.
        batch_data = self.dropout(batch_data)
        batch_data = batch_data.mean(dim=2)  # average over the length axis, dropping that dim
        return batch_data

# Build the network and move it to the target device; Module.to() moves in place and returns self.
model = ECG_Bin(device=device)
model.to(device)

# Loss function: cross-entropy over the class scores.
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

######################################## Optimizer ##############################################
# Adam over all model parameters at the base learning rate.
optimizer = optim.Adam(model.parameters(), lr=lr)
print(device)

#################################### Model parameter summary ####################################
from torchinfo import summary

# Print a per-layer table (shapes, parameter counts) for one dummy batch of single-lead signals.
summary(
    model=model,
    input_size=(batch_size, 1, 3600),  # make sure this is 'input_size', not 'input_shape'
    col_names=['input_size', 'output_size', 'num_params', 'trainable'],
    col_width=20,
    row_settings=['var_names'],
)

# Helper from My_Basic_Net bound to the model; its WeightBinarize() is called after training.
weightOperation = WeightOperation(model)

epochs = 1000
seed = 110

# Seed every RNG the script touches.
# NOTE(review): the model weights (ECG_Bin construction) and the train/test split are created
# earlier in the script, before these seeds are set, so they are NOT made reproducible by
# them; only later randomness (DataLoader shuffling, dropout) is.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Per-epoch metric history, plus the best accuracies and the (1-based) epochs they occurred at.
results = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
best_test_acc = 0.0
best_train_acc = 0.0
best_test_epoch = 0
best_train_epoch = 0

# Make sure the model is on the target device. ECG_Bin was already moved to `device` when it
# was created; this presumably guards against the torchinfo summary() call having moved it
# (TODO confirm). Module.to() is a no-op when the model is already there.
model.to(device)

################################# Learning-rate schedule #################################

# Candidate schedulers (attach only ONE to the optimizer):
#   1. Cosine annealing with warm *restarts* (periodic LR resets, not a warm-up):
#        torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs, T_mult=1, eta_min=...)
#   2. Cosine annealing (used below): T_max is the number of scheduler steps taken to anneal
#      from the base LR down to eta_min, the lowest LR the schedule reaches.
#   3. Exponential decay (LR multiplied by gamma on every step):
#        torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=...)
#
# eta_min must be below the base learning rate `lr`. A value above it (e.g. 0.05 with
# lr = 0.002) makes the "annealing" raise the LR to 25x its starting value.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=lr * 0.01, last_epoch=-1
)
# NOTE(review): the schedule only takes effect if lr_scheduler.step() is called once per
# epoch in the training loop.

# Per-epoch curves used for the final plot: mean batch loss and accuracy for each split.
train_losses = []
train_acces = []
eval_losses = []
eval_acces = []

def _train_one_epoch(net, loader, criterion, opt, dev):
    """Run one optimisation pass of *net* over *loader*.

    Args:
        net: model to train; switched to train mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(logits, labels)``.
        opt: optimizer updating ``net``'s parameters.
        dev: device every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``: the average per-batch loss and the fraction of
        correctly classified samples.
    """
    net.train()  # training mode (e.g. dropout active)
    loss_sum, correct, total = 0.0, 0, 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(dev), batch_y.to(dev)
        # Forward pass and loss.
        logits = net(batch_x)
        loss = criterion(logits, batch_y)
        # Clear stale gradients, back-propagate, update the parameters.
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Accumulate statistics; no autograd graph is needed for these.
        loss_sum += loss.item()
        predicted = logits.detach().argmax(dim=1)  # highest-scoring class per sample
        correct += (predicted == batch_y).sum().item()
        total += batch_y.size(0)
    return loss_sum / len(loader), correct / total


def _evaluate(net, loader, criterion, dev):
    """Evaluate *net* on *loader* without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)`` with the same meaning as in ``_train_one_epoch``.
    """
    net.eval()  # evaluation mode (e.g. dropout disabled)
    loss_sum, correct, total = 0.0, 0, 0
    with torch.inference_mode():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(dev), batch_y.to(dev)
            logits = net(batch_x)
            loss_sum += criterion(logits, batch_y).item()
            predicted = logits.argmax(dim=1)  # highest-scoring class per sample
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    return loss_sum / len(loader), correct / total


for epoch in range(epochs):
    ################################## Train ########################################
    train_loss, train_acc = _train_one_epoch(model, train_loader, loss_fn, optimizer, device)
    # NOTE(review): lr_scheduler is built before this loop but never stepped, so the LR stays
    # at `lr`. To enable the schedule, call lr_scheduler.step() here once per epoch, after
    # making sure its eta_min is below `lr`.

    ########################## Evaluate after every epoch ##################################
    test_loss, test_acc = _evaluate(model, test_loader, loss_fn, device)

    # The losses are already per-batch means; store them as-is (no second division).
    train_losses.append(train_loss)
    train_acces.append(train_acc)
    eval_losses.append(test_loss)
    eval_acces.append(test_acc)

    # Track the best accuracies and the 1-based epochs they were reached at.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_test_epoch = epoch + 1
    if train_acc > best_train_acc:
        best_train_acc = train_acc
        best_train_epoch = epoch + 1

    print(
        f'Epoch: {epoch + 1} | '
        f'train_loss: {train_loss:.4f} | '
        f'train_acc: {train_acc:.4f} | '
        f'test_loss: {test_loss:.4f} | '
        f'test_acc: {test_acc:.4f}'
    )

    results['train_loss'].append(train_loss)
    results['train_acc'].append(train_acc)
    results['test_loss'].append(test_loss)
    results['test_acc'].append(test_acc)

################################## Post-training report ##################################

# Apply the helper's WeightBinarize() (presumably binarising the model weights in place —
# see My_Basic_Net.utils.OP.WeightOperation), then dump the resulting parameters.
weightOperation.WeightBinarize()

print(model.state_dict())

print('best_test_acc: ', '{:.2f}'.format(best_test_acc * 100) + '%', ' epoch: ', best_test_epoch)
print('best_train_acc: ', '{:.2f}'.format(best_train_acc * 100) + '%', ' epoch: ', best_train_epoch)
print('-' * 50 + ' ')

# Loss and accuracy curves share one y axis; the x axis is the 0-based epoch index.
plt.plot(np.arange(len(train_losses)), train_losses, label='train loss')
plt.plot(np.arange(len(train_acces)), train_acces, label='train acc')
plt.plot(np.arange(len(eval_losses)), eval_losses, label='test loss')
plt.plot(np.arange(len(eval_acces)), eval_acces, label='test acc')
plt.legend()
plt.xlabel('epoches')
plt.title('Model loss & accuracy')
plt.show()


# Source: https://www.cveoy.top/t/topic/mmkt — copyright belongs to the original author; do not repost or scrape.

# NOTE: trailing advertisement text from the scraped web page removed; it is not part of the program.