以下是基于PyTorch的文本分类代码示例,以AG_NEWS数据集为例:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build the tokenizer and vocabulary from the AG_NEWS training split.
# NOTE: AG_NEWS iterates (label, text) tuples, so the tokenizer must be
# applied to the text field only — mapping it over the raw iterator would
# pass tuples to the tokenizer and crash.
train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')

def _yield_tokens(data_iter):
    """Yield the token list of each sample's text field."""
    for _label, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(_yield_tokens(train_iter), specials=['<unk>'])
# Unknown tokens map to the <unk> index instead of raising KeyError.
vocab.set_default_index(vocab['<unk>'])

# Model definition: bag-of-embeddings text classifier.
class TextClassificationModel(nn.Module):
    """Embeds token ids with an EmbeddingBag (mean pooling) and projects to class logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding vector width.
        num_classes: number of output classes (logit dimension).
    """

    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(TextClassificationModel, self).__init__()
        # NOTE: sparse=True was removed — sparse embedding gradients are not
        # supported by the Adam optimizer used to train this model (only
        # SGD/SparseAdam accept them), so optimizer.step() would raise.
        self.embedding = nn.EmbeddingBag(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, text, offsets=None):
        """Return class logits for a batch of token-id bags.

        Args:
            text: 2-D tensor (batch, seq_len) of token ids — one bag per row —
                or a 1-D concatenated tensor when `offsets` is given.
            offsets: optional 1-D tensor of bag start positions for 1-D `text`
                (standard EmbeddingBag calling convention); defaults to None
                to stay backward compatible with 2-D callers.
        """
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

# Training loop for a single epoch.
def train(model, iterator, optimizer, criterion):
    """Run one optimisation epoch over `iterator`.

    Each batch must expose `.text` (model input) and `.label` (targets).
    Gradients are zeroed per batch; parameters are updated in place.
    Returns None.
    """
    model.train()  # enable dropout/batch-norm training behaviour

    for batch in iterator:
        optimizer.zero_grad()
        logits = model(batch.text)
        step_loss = criterion(logits, batch.label)
        step_loss.backward()
        optimizer.step()

# 定义测试函数
def test(model, iterator, criterion):
    model.eval()
    correct = 0
    total = 0
    loss = 0

    with torch.no_grad():
        for batch in iterator:
            text, labels = batch.text, batch.label
            predictions = model(text)
            loss += criterion(predictions, labels).item()
            predicted_classes = torch.argmax(predictions, dim=1)
            correct += (predicted_classes == labels).sum().item()
            total += len(labels)

    accuracy = correct / total
    avg_loss = loss / len(iterator)
    return accuracy, avg_loss

# Hyperparameters.
embedding_dim = 32
# AG_NEWS has exactly four classes (World, Sports, Business, Sci/Tech);
# the original `train_iter.get_labels()` does not exist on torchtext datasets.
num_classes = 4
learning_rate = 0.01
num_epochs = 5
batch_size = 64

# Model, optimizer and loss.
model = TextClassificationModel(len(vocab), embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


class _Batch:
    """Minimal container so train()/test() can read batch.text / batch.label."""

    def __init__(self, text, label):
        self.text = text
        self.label = label


def collate_batch(samples):
    """Turn a list of raw (label, text) pairs into dense tensors.

    AG_NEWS labels are 1-based (1..4); shift to 0-based for CrossEntropyLoss.
    Token-id sequences are right-padded with the <unk> index so EmbeddingBag
    can consume one 2-D (batch, max_len) tensor.
    """
    token_ids = [torch.tensor(vocab(tokenizer(text)), dtype=torch.long)
                 for _label, text in samples]
    labels = torch.tensor([label - 1 for label, _text in samples], dtype=torch.long)
    # Guard against an all-empty batch producing a zero-width tensor.
    max_len = max(1, max(len(t) for t in token_ids))
    padded = torch.full((len(token_ids), max_len), vocab['<unk>'], dtype=torch.long)
    for i, ids in enumerate(token_ids):
        padded[i, :len(ids)] = ids
    return _Batch(padded, labels)


# Materialize the iterable datapipes into lists: shuffle=True requires a
# map-style dataset, and a collate_fn is needed to numericalize the raw text.
train_iter, test_iter = AG_NEWS()
train_loader = torch.utils.data.DataLoader(
    list(train_iter), batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
# No point shuffling the evaluation split.
test_loader = torch.utils.data.DataLoader(
    list(test_iter), batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

# Train and report held-out accuracy/loss each epoch.
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion)
    accuracy, avg_loss = test(model, test_loader, criterion)
    print(f'Epoch {epoch+1}: accuracy={accuracy:.4f}, loss={avg_loss:.4f}')

在这个示例中,我们使用了torchtext库中的AG_NEWS数据集,并使用基本英语分词器构建了词汇表。然后,我们定义了一个简单的文本分类模型,它将文本嵌入到一个嵌入层中,并在其上应用线性层以生成类别预测。我们还定义了训练和测试函数,以便我们可以训练模型并评估其性能。最后,我们使用Adam优化器和交叉熵损失函数来训练模型。

PyTorch文本分类代码实战:AG_NEWS数据集示例

原文地址: https://www.cveoy.top/t/topic/nPas 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录