Python实现RNN循环神经网络:代码详解及示例
以下是使用Python实现RNN(循环神经网络)的基本结构,包括前向传播和反向传播的代码:
import numpy as np
# 定义RNN类
class RNN:
def __init__(self, input_size, hidden_size, output_size):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
# 初始化权重和偏置
self.Wxh = np.random.randn(hidden_size, input_size) * 0.01
self.Whh = np.random.randn(hidden_size, hidden_size) * 0.01
self.Why = np.random.randn(output_size, hidden_size) * 0.01
self.bh = np.zeros((hidden_size, 1))
self.by = np.zeros((output_size, 1))
# 前向传播
def forward(self, x):
# 初始化隐藏状态
h = np.zeros((self.hidden_size, 1))
# 遍历时间步
for t in range(len(x)):
# 更新隐藏状态
h = np.tanh(self.Wxh.dot(x[t]) + self.Whh.dot(h) + self.bh)
# 计算输出
y = self.Why.dot(h) + self.by
return y, h
# 反向传播
def backward(self, x, y_true, y_pred, h):
# 初始化梯度
dWxh = np.zeros_like(self.Wxh)
dWhh = np.zeros_like(self.Whh)
dWhy = np.zeros_like(self.Why)
dbh = np.zeros_like(self.bh)
dby = np.zeros_like(self.by)
dhnext = np.zeros_like(h)
# 计算输出误差
dy = y_pred - y_true
# 反向传播时间步
for t in reversed(range(len(x))):
# 计算当前时间步的梯度
dWhy += dy.dot(h.T)
dby += dy
dh = self.Why.T.dot(dy) + dhnext
dhraw = (1 - h ** 2) * dh
dbh += dhraw
dWxh += dhraw.dot(x[t].T)
dWhh += dhraw.dot(h.T)
dhnext = self.Whh.T.dot(dhraw)
# 梯度裁剪
for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
np.clip(dparam, -5, 5, out=dparam)
# 更新权重和偏置
self.Wxh -= dWxh * learning_rate
self.Whh -= dWhh * learning_rate
self.Why -= dWhy * learning_rate
self.bh -= dbh * learning_rate
self.by -= dby * learning_rate
# 定义超参数
input_size = 3
hidden_size = 4
output_size = 2
learning_rate = 0.1
# 初始化RNN
RNN = RNN(input_size, hidden_size, output_size)
# 定义输入和标签
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
y_true = np.array([[0, 1], [1, 0], [0, 1], [1, 0]])
# 训练RNN
for i in range(1000):
# 前向传播
y_pred, h = rnn.forward(x)
# 计算损失
loss = np.sum((y_pred - y_true) ** 2)
# 反向传播
rnn.backward(x, y_true, y_pred, h)
# 每100次迭代输出一次损失
if i % 100 == 0:
print('Iteration', i, 'Loss', loss)
以上代码实现了一个简单的RNN,可以通过修改超参数和输入数据来训练不同的模型。
原文地址: https://www.cveoy.top/t/topic/nzd9 著作权归作者所有。请勿转载和采集!