使用 ALBERT 模型进行八分类任务:完整 Python 代码和数据集示例
以下是使用 ALBERT 模型对数据集进行八分类的完整 Python 代码:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AlbertTokenizer, AlbertModel
# 定义数据集类
class MyDataset(Dataset):
    """Dataset of CSV-style text lines: comma-separated float features
    followed by one trailing integer class label per line.
    """

    def __init__(self, file_path):
        """Parse ``file_path`` into in-memory feature/label lists.

        Args:
            file_path: path to a text file with one sample per line.
        """
        self.samples = []
        self.labels = []
        # Fix: explicit encoding (the original relied on the platform
        # default) and stream the file instead of readlines().
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Fix: the original crashed on blank/trailing lines
                # (``int('')``); skip them instead.
                if not line:
                    continue
                fields = line.split(',')
                self.samples.append([float(v) for v in fields[:-1]])
                self.labels.append(int(fields[-1]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Explicit dtypes: float32 features, int64 labels
        # (nn.CrossEntropyLoss requires long targets).
        sample = torch.tensor(self.samples[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample, label
# 定义模型类
class AlbertClassifier(nn.Module):
    """ALBERT encoder with a dropout + linear head mapping the pooled
    [CLS] representation to 8 class logits.

    NOTE(review): ALBERT expects ``input_ids`` to be integer vocabulary
    indices; confirm the caller tokenizes its inputs accordingly.
    """

    def __init__(self):
        super().__init__()
        # Pretrained backbone; the head width follows its hidden size.
        self.albert = AlbertModel.from_pretrained('albert-base-v2')
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.albert.config.hidden_size, 8)

    def forward(self, input_ids, attention_mask):
        """Return per-class logits of shape (batch, 8)."""
        encoder_out = self.albert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoder_out.pooler_output)
        return self.fc(pooled)
# 定义训练函数
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each batch yields ``(inputs, labels)`` where the first 23 columns of
    ``inputs`` are fed to the model as token ids and ``labels`` are the
    integer class targets.

    Args:
        model: module called as ``model(input_ids, attention_mask=...)``
            returning per-class logits.
        train_loader: DataLoader over training batches.
        val_loader: DataLoader over validation batches.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss``.
        optimizer: optimizer built over ``model.parameters()``.
        num_epochs: number of full passes over ``train_loader``.

    Side effects:
        Writes the best-so-far weights to 'best_model.pth' and prints
        per-epoch metrics.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Bug fix: embedding lookups require integer ids — the original
            # passed a float tensor as input_ids, which raises at runtime.
            input_ids = inputs[:, :23].long()
            attention_mask = (input_ids != 0)
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            _, preds = torch.max(outputs, 1)
            train_loss += loss.item() * inputs.size(0)
            # Fix: accumulate a Python int via .item(), not a tensor.
            train_correct += torch.sum(preds == labels).item()
        train_loss /= len(train_loader.dataset)
        train_acc = train_correct / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        val_correct = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                input_ids = inputs[:, :23].long()
                attention_mask = (input_ids != 0)
                outputs = model(input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                val_loss += loss.item() * inputs.size(0)
                val_correct += torch.sum(preds == labels).item()
        val_loss /= len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)

        print('Epoch {}/{} - Train Loss: {:.4f} - Train Acc: {:.4f} - Val Loss: {:.4f} - Val Acc: {:.4f}'.format(
            epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc))
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
    print('Training complete!')
# 加载数据集
# ---------------------------------------------------------------------------
# Script entry: build datasets/loaders, train, then evaluate on the test set.
# ---------------------------------------------------------------------------
# Load the three splits (one CSV line per sample: features + trailing label).
train_dataset = MyDataset('train.txt')
val_dataset = MyDataset('val.txt')
test_dataset = MyDataset('test.txt')

# DataLoaders; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Model, loss, and optimizer.
# NOTE(review): lr=0.001 is high for fine-tuning a pretrained transformer;
# typical values are 1e-5..5e-5 — confirm before long runs.
model = AlbertClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train; the best validation-accuracy weights go to 'best_model.pth'.
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

# Bug fix: `device` was only a local variable inside train(); the original
# test loop below raised NameError. Define it here before use.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
model.eval()
test_correct = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Bug fix: embedding lookups require integer ids, not floats.
        input_ids = inputs[:, :23].long()
        outputs = model(input_ids, attention_mask=(input_ids != 0))
        _, preds = torch.max(outputs, 1)
        # .item(): accumulate a Python int, not a tensor.
        test_correct += torch.sum(preds == labels).item()
test_acc = test_correct / len(test_loader.dataset)
print('Test Acc: {:.4f}'.format(test_acc))
请确保已安装以下依赖库:
- torch
- transformers
- numpy
在运行代码之前,需要将数据集文件按照训练集、验证集和测试集分别保存为'train.txt'、'val.txt'和'test.txt'文件。代码中使用的是'albert-base-v2'预训练模型:首次运行时 transformers 库会自动从模型仓库下载该模型(需要网络连接),也可以提前手动下载后改用本地路径加载。训练过程中会将在验证集上取得最佳准确率的模型参数保存到'best_model.pth'文件中,测试阶段将加载该模型进行预测并计算准确率。
数据集示例:
7,7,183,233,10,10,3,10,3,10,0,25,21,42,194,0,0,2,1,0,0,32,1.23,4
7,7,183,233,10,10,3,10,3,10,0,25,21,0,0,2,78,2,1,0,0,86.6685638427734,1.25,4
7,7,183,233,10,10,3,10,3,10,0,25,21,90,80,20,10,2,1,0,0,86.64013671875,1.30,0
7,7,183,233,10,10,3,10,3,10,0,25,21,90,80,20,10,2,1,0,0,86.4980087280273,1.10,0
7,0,183,0,9,0,3,10,3,0,0,25,123,90,80,20,10,0,1,0,1,0,1.00,7
7,0,183,0,9,0,3,10,3,0,0,25,123,90,80,20,10,0,1,0,1,0,1.00,7
...
每条样本前 23 位为特征值,最后一位为分类标签。注意:ALBERT 的 input_ids 本应是分词器产生的整数词表索引,直接把原始数值特征当作 token id 输入在语义上并不合理,实际使用时应先对输入做分词编码,或改用适合表格数值特征的模型。
原文地址: https://www.cveoy.top/t/topic/qDFI 著作权归作者所有。请勿转载和采集!