该代码实现了一个基于循环神经网络的中文拼音文本生成模型。具体来说,代码实现了以下功能:

  1. 读取并处理拼音文本数据,包括生成字符到索引和索引到字符的字典,以及实现对输入序列的'one-hot'编码。
import torch
import random
import numpy as np

# Load the raw pinyin corpus; each element of pinyin_data is one line of the file.
with open('/kaggle/input/pinyin-data/pinyin.txt', 'r', encoding='utf-8') as f:
    pinyin_data = f.read().split('\n')

print(len(pinyin_data))
print(pinyin_data[:10])

# Build the vocabulary from every distinct character in the corpus.
# - sorted(): iterating a bare set gives a different order on every interpreter
#   run, which would silently change the index of each character (and so make
#   saved weights or cached encodings meaningless across runs).
# - {' '}: the data loaders replace line breaks with a space before encoding and
#   the smoke test encodes 'ni hao', so the space must always be encodable.
vocab_chars = sorted(set(''.join(pinyin_data)) | {' '})

char_to_idx = {char: i for i, char in enumerate(vocab_chars)}
idx_to_char = dict(enumerate(vocab_chars))

print(char_to_idx)
print(idx_to_char)

vocab_size = len(char_to_idx)
print('Vocab size:', vocab_size)

def one_hot_encode(sequence, vocab_size, dtype=torch.float32):
    """Return a (len(sequence), vocab_size) tensor with one 1 per row.

    Row ``i`` has its 1 in the column given by the module-level
    ``char_to_idx`` mapping for ``sequence[i]``; all other entries are 0.

    Args:
        sequence: Sized iterable of characters present in ``char_to_idx``.
        vocab_size: Width of the encoding (number of columns).
        dtype: Element type of the returned tensor.

    Raises:
        KeyError: If a character is missing from ``char_to_idx``.
    """
    length = len(sequence)
    encoded = torch.zeros(length, vocab_size, dtype=dtype)
    columns = torch.tensor([char_to_idx[ch] for ch in sequence], dtype=torch.long)
    # Set one (row, column) cell per character in a single indexed assignment.
    encoded[torch.arange(length), columns] = 1
    return encoded

# Smoke test: encode a short pinyin string and show the resulting tensor.
test_sequence = 'ni hao'
test_one_hot = one_hot_encode(sequence=test_sequence, vocab_size=vocab_size)
print(test_one_hot)
  2. 实现两种数据批量加载方式,即随机采样和顺序划分。
def data_loader_random(text, batch_size, num_steps):
    """Endlessly yield (inputs, targets) batches of randomly ordered windows.

    The text (line breaks replaced by spaces) is one-hot encoded and cut into
    non-overlapping windows of ``num_steps`` characters. On every pass over
    the corpus the windows are reshuffled, so neither the members of a batch
    nor consecutive batches are adjacent in the text.

    Args:
        text: Corpus string; every character must be encodable by
            ``one_hot_encode``.
        batch_size: Number of windows per batch.
        num_steps: Window length in characters.

    Yields:
        ``(inputs, targets)`` where both have shape
        ``(batch_size, num_steps - 1, vocab_size)`` and ``targets`` is the
        window shifted one character ahead of ``inputs``. The generator never
        terminates; the consumer decides how many batches form an epoch.

    Raises:
        ValueError: On first iteration, if the corpus is too short to form a
            single batch (the original looped forever without yielding).
    """
    corpus = text.replace('\n', ' ').replace('\r', ' ')
    corpus_one_hot = one_hot_encode(corpus, vocab_size)
    num_batches = corpus_one_hot.shape[0] // (batch_size * num_steps)
    if num_batches <= 0:
        raise ValueError(
            "text is too short: need at least batch_size * num_steps = %d "
            "characters, got %d" % (batch_size * num_steps, corpus_one_hot.shape[0])
        )
    # Drop the tail that does not fill a whole batch, then view the corpus as
    # a stack of fixed-length windows: (num_windows, num_steps, vocab_size).
    usable = num_batches * batch_size * num_steps
    windows = corpus_one_hot[:usable].view(-1, num_steps, vocab_size)
    while True:
        # A fresh permutation per pass is what makes the sampling random.
        order = torch.randperm(windows.shape[0])
        for i in range(num_batches):
            picked = order[i * batch_size:(i + 1) * batch_size]
            batch = windows[picked]
            inputs = batch[:, :-1, :]
            targets = batch[:, 1:, :]
            yield inputs, targets

def data_loader_sequential(text, batch_size, num_steps):
    """Endlessly yield (inputs, targets) batches in corpus order.

    The text (line breaks replaced by spaces) is one-hot encoded and split
    into ``batch_size`` contiguous rows. Batch ``i`` holds steps
    ``[i * num_steps, (i + 1) * num_steps)`` of every row, so row ``b`` of one
    batch continues exactly where row ``b`` of the previous batch stopped.

    The previous implementation reshaped ``N * batch_size * num_steps + 1``
    rows into ``(batch_size, -1, num_steps, vocab_size)``, which cannot
    succeed whenever ``batch_size * num_steps > 1``; inputs and targets are
    now taken as two views of the corpus offset by one character.

    Args:
        text: Corpus string; every character must be encodable by
            ``one_hot_encode``.
        batch_size: Number of parallel rows the corpus is split into.
        num_steps: Number of time steps per batch.

    Yields:
        ``(inputs, targets)``, both of shape
        ``(batch_size, num_steps, vocab_size)``; ``targets`` is ``inputs``
        shifted one character ahead. The generator never terminates; the
        consumer decides how many batches form an epoch.

    Raises:
        ValueError: On first iteration, if the corpus is too short to form a
            single batch plus the one extra character needed for targets.
    """
    corpus = text.replace('\n', ' ').replace('\r', ' ')
    corpus_one_hot = one_hot_encode(corpus, vocab_size)
    # One character is reserved so the last input still has a next-char target.
    num_batches = (corpus_one_hot.shape[0] - 1) // (batch_size * num_steps)
    if num_batches <= 0:
        raise ValueError(
            "text is too short: need at least batch_size * num_steps + 1 = %d "
            "characters, got %d" % (batch_size * num_steps + 1, corpus_one_hot.shape[0])
        )
    usable = num_batches * batch_size * num_steps
    steps_per_row = num_batches * num_steps
    all_inputs = corpus_one_hot[:usable].view(batch_size, steps_per_row, vocab_size)
    all_targets = corpus_one_hot[1:usable + 1].view(batch_size, steps_per_row, vocab_size)
    while True:
        for i in range(num_batches):
            span = slice(i * num_steps, (i + 1) * num_steps)
            inputs = all_inputs[:, span, :]
            targets = all_targets[:, span, :]
            yield inputs, targets
  3. 定义了一个基于循环神经网络的生成模型,可以选择使用RNN或GRU作为循环层。
class RNNModel(torch.nn.Module):
    """Recurrent sequence model: a stacked RNN/GRU followed by a linear layer.

    Args:
        input_size: Size of each input feature vector.
        hidden_size: Number of hidden units in every recurrent layer.
        output_size: Number of scores produced per time step.
        num_layers: Number of stacked recurrent layers.
        rnn_type: ``'rnn'`` (``torch.nn.RNN``) or ``'gru'`` (``torch.nn.GRU``).
        device: Stored as ``self.device`` for callers that read it. The
            forward pass does not use it; see ``forward``.

    Raises:
        ValueError: If ``rnn_type`` is not one of the supported names.
    """

    # Supported recurrent layer classes, keyed by the public ``rnn_type`` name.
    _RNN_LAYERS = {'rnn': torch.nn.RNN, 'gru': torch.nn.GRU}

    def __init__(self, input_size, hidden_size, output_size, num_layers, rnn_type='rnn', device='cpu'):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        self.device = device
        try:
            layer_cls = self._RNN_LAYERS[rnn_type]
        except (KeyError, TypeError):
            raise ValueError(f"Invalid RNN type: {rnn_type}") from None
        self.rnn = layer_cls(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x, h=None):
        """Run the recurrent stack over ``x`` and score every time step.

        Args:
            x: Input of shape ``(batch, steps, input_size)`` (the recurrent
                layer is built with ``batch_first=True``).
            h: Optional hidden state of shape
                ``(num_layers, batch, hidden_size)``; zeros when omitted.

        Returns:
            ``(out, h)`` where ``out`` has shape
            ``(batch * steps, output_size)`` and ``h`` is the final hidden
            state.
        """
        if h is None:
            # Allocate the initial state next to the input (same device and
            # dtype). Using the constructor's ``device`` here broke as soon as
            # the module was moved with ``.to(...)`` to a different device.
            h = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        out, h = self.rnn(x, h)
        # Flatten (batch, steps) so the linear layer scores each step.
        out = out.contiguous().view(-1, self.hidden_size)
        out = self.fc(out)
        return out, h
  4. 实现对生成模型的训练和预测功能,包括训练过程中的梯度裁剪和预测时的贪心(argmax)字符选择。
def _targets_to_class_indices(targets):
    """Convert a batch of targets to integer class indices of shape (batch, steps)."""
    if targets.dim() == 3:
        if targets.size(-1) == 1 and not targets.is_floating_point():
            # Already indices, just carrying a trailing singleton dimension.
            return targets.squeeze(-1).long()
        # One-hot (or score) rows: the class is the position of the maximum.
        return targets.argmax(dim=-1)
    return targets.long()


def train(model, data_loader, num_epochs, learning_rate, grad_clip, device='cpu',
          steps_per_epoch=None):
    """Train ``model`` for next-step prediction with Adam and gradient clipping.

    Each batch is unrolled one time step at a time; the per-step cross-entropy
    losses are averaged over the steps before a single backward pass. The
    hidden state is reset at the start of every batch.

    Args:
        model: Module called as ``model(x, h)`` with ``x`` of shape
            ``(batch, 1, features)``; must return ``(scores, h)`` with
            ``scores`` of shape ``(batch, num_classes)``.
        data_loader: Iterable of ``(inputs, targets)``. ``inputs`` has shape
            ``(batch, steps, features)``; ``targets`` is either one-hot with
            shape ``(batch, steps, num_classes)`` or class indices with shape
            ``(batch, steps)``.
        num_epochs: Number of epochs to run.
        learning_rate: Adam learning rate.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        device: Device the model and batches are moved to.
        steps_per_epoch: Maximum number of batches drawn from ``data_loader``
            per epoch. ``None`` (default) consumes the loader until it is
            exhausted, which never happens for an endless generator, so set
            this when the loader yields forever.

    Raises:
        ValueError: If an epoch sees no usable batch (for example a one-shot
            generator that was exhausted by an earlier epoch).
    """
    from itertools import islice

    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        total_loss = 0.0
        n = 0
        for inputs, targets in islice(data_loader, steps_per_epoch):
            inputs = inputs.to(device)
            # CrossEntropyLoss wants class indices here; the flattened one-hot
            # tensor used previously had the wrong shape and dtype.
            labels = _targets_to_class_indices(targets.to(device))
            batch_size = inputs.size(0)
            num_steps = inputs.size(1)
            if num_steps == 0:
                # Nothing to predict in this batch; avoid dividing by zero.
                continue
            h = None
            loss = 0.0
            model.zero_grad()
            for i in range(num_steps):
                x = inputs[:, i:i + 1, :]
                y_pred, h = model(x, h)
                loss = loss + criterion(y_pred, labels[:, i])
            loss = loss / num_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            total_loss += loss.item() * batch_size
            n += batch_size
        if n == 0:
            raise ValueError(
                "data_loader produced no usable batches in epoch %d" % (epoch + 1)
            )
        print("Epoch %d, Loss: %.4f" % (epoch + 1, total_loss / n))

def predict(model, prefix, num_predictions, device='cpu'):
    """Greedily extend ``prefix`` by ``num_predictions`` characters.

    The whole prefix is fed through the model once to warm up the hidden
    state. The first new character is read from the scores at the last prefix
    position; every later character is obtained by feeding the previously
    generated one. (The earlier version fed the last prefix character a
    second time after the warm-up pass, so the hidden state saw it twice.)

    Args:
        model: Module called as ``model(x, h)`` with ``x`` of shape
            ``(1, steps, vocab_size)``. Assumed to return ``(scores, h)`` with
            one row of scores per time step, as ``RNNModel.forward`` does.
        prefix: Non-empty seed string; every character must be encodable by
            ``one_hot_encode``.
        num_predictions: Number of characters to append.
        device: Device the model and inputs are moved to.

    Returns:
        ``prefix`` followed by the generated characters.

    Raises:
        ValueError: If ``prefix`` is empty.
    """
    if not prefix:
        raise ValueError("prefix must contain at least one character")
    model.to(device)
    generated = []
    with torch.no_grad():
        prefix_one_hot = one_hot_encode(prefix, vocab_size).to(device)
        scores, h = model(prefix_one_hot.unsqueeze(0))
        for step in range(num_predictions):
            if step > 0:
                # Feed back the character chosen in the previous iteration.
                x = one_hot_encode(generated[-1], vocab_size).to(device)
                scores, h = model(x.unsqueeze(0), h)
            # The last row holds the scores for the most recent time step.
            generated.append(idx_to_char[scores[-1].argmax().item()])
    return prefix + ''.join(generated)

# 'gpu' is not a device string torch accepts; use CUDA when it is available
# and fall back to the CPU otherwise. The same device is handed to the model,
# to training and to prediction so tensors and weights stay together.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): data_loader_random / data_loader_sequential are written as
# endless generators (`while True`) and train() iterates the loader once per
# epoch - confirm that an epoch actually terminates before relying on this run.

# Random sampling
batch_size = 32
num_steps = 10
data_loader = data_loader_random('\n'.join(pinyin_data), batch_size, num_steps)
model = RNNModel(vocab_size, 128, vocab_size, 2, rnn_type='gru', device=device)
train(model, data_loader, num_epochs=10, learning_rate=0.001, grad_clip=1.0, device=device)

# Test
prefix = 'ni'
num_predictions = 10
print(predict(model, prefix, num_predictions, device=device))

# Sequential partitioning
batch_size = 32
num_steps = 10
data_loader = data_loader_sequential('\n'.join(pinyin_data), batch_size, num_steps)
model = RNNModel(vocab_size, 128, vocab_size, 2, rnn_type='gru', device=device)
train(model, data_loader, num_epochs=10, learning_rate=0.001, grad_clip=1.0, device=device)

# Test
prefix = 'ni'
num_predictions = 10
print(predict(model, prefix, num_predictions, device=device))

总的来说,代码实现了一个基本的中文拼音文本生成模型,并提供了两种数据批量加载方式和两种循环神经网络模型选择。需要注意的是,在训练和预测时需要选择合适的设备(GPU或CPU),以便提高运行速度和减少内存占用。

基于循环神经网络的中文拼音文本生成模型

原文地址: https://www.cveoy.top/t/topic/onEn 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录