PyTorch文本分类代码实战:AG_NEWS数据集示例
以下是基于 PyTorch 的文本分类代码示例，以 AG_NEWS 数据集为例：
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier.

    An ``nn.EmbeddingBag`` pools the embeddings of all tokens of a sample
    (mean pooling, the EmbeddingBag default mode) and a linear layer maps the
    pooled vector to one score per class.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        num_classes: Number of output classes.
        sparse: Whether the embedding produces sparse gradients. Defaults to
            False because ``torch.optim.Adam`` rejects sparse gradients; set it
            to True only together with an optimizer that supports them
            (e.g. ``torch.optim.SGD`` or ``torch.optim.SparseAdam``).
    """

    def __init__(self, vocab_size, embedding_dim, num_classes, sparse=False):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embedding_dim, sparse=sparse)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, text, offsets=None):
        """Return class scores, one row per sample.

        Args:
            text: Either a 1-D tensor of concatenated token ids (``offsets``
                is then required), or a 2-D ``(batch, seq_len)`` tensor of
                equal-length samples (``offsets`` must then be None).
            offsets: 1-D tensor holding the start index of every sample
                inside a 1-D ``text``.
        """
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


# ---------------------------------------------------------------------------
# Batching
# ---------------------------------------------------------------------------
def make_collate_fn(tokenizer, vocab):
    """Build a DataLoader ``collate_fn`` for raw ``(label, text)`` samples.

    The returned function turns a list of samples into a tuple
    ``(labels, token_ids, offsets)``:

    * ``labels``: int64 tensor of 0-based class indices. AG_NEWS labels start
      at 1, while ``CrossEntropyLoss`` expects indices starting at 0.
    * ``token_ids``: the token ids of all samples concatenated into one
      1-D int64 tensor.
    * ``offsets``: int64 tensor with the start position of each sample in
      ``token_ids``. This is the flat-input format ``nn.EmbeddingBag`` expects.

    Args:
        tokenizer: Callable mapping a string to a list of tokens.
        vocab: Callable mapping a list of tokens to a list of integer ids.
    """

    def collate_batch(batch):
        labels, token_ids, lengths = [], [], []
        for label, text in batch:
            ids = vocab(tokenizer(text))
            labels.append(int(label) - 1)
            token_ids.extend(ids)
            lengths.append(len(ids))
        lengths_t = torch.tensor(lengths, dtype=torch.int64)
        # Exclusive prefix sum: offsets[i] = total length of samples 0..i-1.
        offsets = lengths_t.cumsum(dim=0) - lengths_t
        return (
            torch.tensor(labels, dtype=torch.int64),
            torch.tensor(token_ids, dtype=torch.int64),
            offsets,
        )

    return collate_batch


# ---------------------------------------------------------------------------
# Training / evaluation
# ---------------------------------------------------------------------------
def train(model, iterator, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: Model called as ``model(text, offsets)``.
        iterator: Iterable of ``(labels, text, offsets)`` batches, such as a
            DataLoader that uses the ``collate_fn`` from ``make_collate_fn``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(predictions, labels)``.
    """
    model.train()
    for labels, text, offsets in iterator:
        optimizer.zero_grad()
        predictions = model(text, offsets)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()


def test(model, iterator, criterion):
    """Evaluate ``model`` on every batch of ``iterator``.

    Args:
        model: Model called as ``model(text, offsets)``.
        iterator: Iterable of ``(labels, text, offsets)`` batches.
        criterion: Loss function returning the batch *mean* loss (the default
            reduction of ``nn.CrossEntropyLoss``).

    Returns:
        ``(accuracy, avg_loss)``: the fraction of correctly classified samples
        and the loss averaged over samples (not over batches).

    Raises:
        ValueError: If ``iterator`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for labels, text, offsets in iterator:
            predictions = model(text, offsets)
            batch_len = labels.size(0)
            # The criterion returns a batch mean. Weight it by batch size so
            # that a short final batch does not skew the average.
            loss_sum += criterion(predictions, labels).item() * batch_len
            correct += (predictions.argmax(dim=1) == labels).sum().item()
            total += batch_len
    if total == 0:
        raise ValueError("evaluation iterator yielded no samples")
    return correct / total, loss_sum / total


# ---------------------------------------------------------------------------
# Script entry point
# ---------------------------------------------------------------------------
def main():
    """Download AG_NEWS, build the vocabulary, then train and evaluate."""
    # Training hyper-parameters.
    embedding_dim = 32
    learning_rate = 0.01
    num_epochs = 5
    batch_size = 64

    # Materialize both splits once. The vocabulary, the label set and the
    # DataLoaders all read the data, and a list is a map-style dataset that
    # DataLoader can shuffle.
    train_data = list(AG_NEWS(split='train'))
    test_data = list(AG_NEWS(split='test'))

    # Samples are (label, text) pairs; only the text is tokenized.
    tokenizer = get_tokenizer('basic_english')
    vocab = build_vocab_from_iterator(
        (tokenizer(text) for _, text in train_data), specials=['<unk>']
    )
    # Tokens missing from the vocabulary (e.g. test-only words) map to <unk>.
    vocab.set_default_index(vocab['<unk>'])

    # Assumes the labels are the contiguous range 1..K, which matches the
    # "label - 1" shift in the collate function.
    num_classes = len({label for label, _ in train_data})

    model = TextClassificationModel(len(vocab), embedding_dim, num_classes)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    collate_batch = make_collate_fn(tokenizer, vocab)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
    )
    # Sample order does not affect the evaluation metrics, so no shuffling.
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
    )

    for epoch in range(num_epochs):
        train(model, train_loader, optimizer, criterion)
        accuracy, avg_loss = test(model, test_loader, criterion)
        print(f'Epoch {epoch+1}: accuracy={accuracy:.4f}, loss={avg_loss:.4f}')


if __name__ == '__main__':
    main()
在这个示例中，我们使用了 torchtext 库中的 AG_NEWS 数据集，并使用基本英语分词器（basic_english）构建了词汇表。然后，我们定义了一个简单的文本分类模型：它先通过 EmbeddingBag 嵌入层把一段文本中各个词的嵌入向量汇总成一个向量（默认取平均），再在其上应用线性层，生成各类别的预测分数。我们还定义了训练和测试函数，用于训练模型并评估其准确率和损失。最后，我们使用 Adam 优化器和交叉熵损失函数来训练模型。
原文地址: https://www.cveoy.top/t/topic/nPas 著作权归作者所有。请勿转载和采集!