基于基因表达量的患者疾病预测:使用 PyTorch 构建神经网络
基于基因表达量的患者疾病预测:使用 PyTorch 构建神经网络
本项目使用 Python 和 PyTorch 框架构建一个神经网络模型,根据基因的表达量预测患者是否患病。模型采用二分类结构,包含一个隐藏层,并使用训练集和测试集进行训练和评估。代码提供详细注释,并输出训练和测试过程中的损失值和准确率。
数据准备
- 数据格式: 数据存储在 Excel 表格中,第一行为患者状态标志'state'(1 为患病,0 为正常)和 8 个基因名称,第 0 列为患者是否患病的真值,其余列为基因的表达量。
- 训练集路径: 'C:\Users\lenovo\Desktop\HIV\PAH三个数据集\selected_genes.xlsx'
- 测试集路径: 'C:\Users\lenovo\Desktop\HIV\PAH三个数据集\GSE53408 对应lasso基因.xlsx'
模型结构
- 模型类型: 二分类模型,用于预测患者是否患病。
- 隐藏层: 1 个隐藏层,包含 4 个神经元。
- 激活函数: sigmoid 函数
训练和评估
- 训练集和测试集: 测试集参与模型训练,以便及时调整模型参数。
- 训练过程: 模型使用训练集进行训练,并记录每个训练轮次的训练集损失值和准确率。
- 测试过程: 模型使用测试集进行评估,并记录每个训练轮次的测试集损失值和准确率。
- 输出结果: 输出每个训练轮次的训练集损失值、训练集准确率、测试集损失值、测试集准确率,以及最后一次训练得到的每个样本的概率。
代码实现
# 导入所需库
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# 读取训练集和测试集
train_data = pd.read_excel('C:\Users\lenovo\Desktop\HIV\PAH三个数据集\selected_genes.xlsx')
test_data = pd.read_excel('C:\Users\lenovo\Desktop\HIV\PAH三个数据集\GSE53408 对应lasso基因.xlsx')
# 定义自定义数据集类
class GeneDataset(Dataset):
def __init__(self, data):
self.data = data.iloc[:, 1:].values.astype(float)
self.target = data.iloc[:, 0].values.astype(float)
def __len__(self):
return len(self.target)
def __getitem__(self, idx):
return torch.from_numpy(self.data[idx]), torch.from_numpy(self.target[idx])
# 定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc2 = nn.Linear(4, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.sigmoid(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
# 定义训练函数
def train(model, device, train_loader, optimizer, criterion):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data.float())
loss = criterion(output, target.view(-1, 1).float())
train_loss += loss.item()
loss.backward()
optimizer.step()
predicted = torch.round(output)
total += target.size(0)
correct += predicted.eq(target.view_as(predicted)).sum().item()
train_acc = 100. * correct / total
train_loss /= len(train_loader.dataset)
print('Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
train_loss, correct, total, train_acc))
return train_loss, train_acc
# 定义测试函数
def test(model, device, test_loader, criterion):
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data.float())
test_loss += criterion(output, target.view(-1, 1).float()).item()
predicted = torch.round(output)
total += target.size(0)
correct += predicted.eq(target.view_as(predicted)).sum().item()
test_acc = 100. * correct / total
test_loss /= len(test_loader.dataset)
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
test_loss, correct, total, test_acc))
return test_loss, test_acc
# 设置超参数
epochs = 50
lr = 0.01
batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 实例化数据集和数据加载器
train_dataset = GeneDataset(train_data)
test_dataset = GeneDataset(test_data)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# 实例化模型、损失函数和优化器
model = Net().to(device)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=lr)
# 训练模型
for epoch in range(1, epochs + 1):
train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
test_loss, test_acc = test(model, device, test_loader, criterion)
print('Epoch {}/{}: Train set - Loss: {:.4f}, Accuracy: {:.0f}%; Test set - Loss: {:.4f}, Accuracy: {:.0f}%
'.format(
epoch, epochs, train_loss, train_acc, test_loss, test_acc))
# 输出每个样本的概率
model.eval()
with torch.no_grad():
for data, _ in test_loader:
data = data.to(device)
output = model(data.float())
print(output)
总结
本项目使用 PyTorch 构建了一个基于基因表达量的患者疾病预测模型。代码实现简单易懂,包含了完整的训练和测试流程,并提供了详细的注释。模型的性能可以通过调整超参数和模型结构进行优化。
原文地址: https://www.cveoy.top/t/topic/nfsJ 著作权归作者所有。请勿转载和采集!