DQN PyTorch Implementation
Below is a PyTorch implementation of DQN (Deep Q-Network):
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque


class DQN(nn.Module):
    """Q-network: maps a state to one Q-value per action."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class ReplayBuffer:
    """Fixed-size experience replay buffer backed by a deque."""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return (np.array(state), np.array(action), np.array(reward, dtype=np.float32),
                np.array(next_state), np.array(done, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    def __init__(self, state_dim, action_dim, lr, gamma, epsilon, eps_decay,
                 buffer_capacity, batch_size, target_update_freq):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.buffer = ReplayBuffer(buffer_capacity)
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.steps = 0
        # Online (policy) network and target network, initialized with identical weights.
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

    def act(self, state):
        # epsilon-greedy action selection
        if random.random() > self.epsilon:
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
            with torch.no_grad():
                q_value = self.policy_net(state)
            action = q_value.argmax(-1).item()
        else:
            action = random.choice(range(self.action_dim))
        return action

    def train(self):
        # Wait until the buffer holds at least one batch of transitions.
        if len(self.buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.buffer.sample(self.batch_size)
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        next_state = torch.tensor(next_state, dtype=torch.float32).to(self.device)
        action = torch.tensor(action, dtype=torch.int64).unsqueeze(-1).to(self.device)
        reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(-1).to(self.device)
        done = torch.tensor(done, dtype=torch.float32).unsqueeze(-1).to(self.device)
        # Q(s, a) for the actions actually taken.
        q_values = self.policy_net(state).gather(-1, action)
        # Bootstrapped target from the target network (no gradient flows through it).
        next_q_values = self.target_net(next_state).max(-1, keepdim=True)[0].detach()
        expected_q_values = reward + self.gamma * next_q_values * (1 - done)
        loss = self.loss_fn(q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.steps += 1
        # Linearly decay epsilon down to a floor of 0.1.
        self.epsilon = max(0.1, self.epsilon - self.eps_decay)
        # Periodically sync the target network with the policy network.
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def store(self, state, action, reward, next_state, done):
        self.buffer.push(state, action, reward, next_state, done)
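The classes above only define the agent. Below is a minimal usage sketch for training it; it assumes the gymnasium package and the CartPole-v1 environment, and the hyperparameter values are illustrative choices rather than part of the original code:

# Minimal training-loop sketch (assumes gymnasium and CartPole-v1; hyperparameters are illustrative).
import gymnasium as gym

env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]   # 4 for CartPole
action_dim = env.action_space.n              # 2 for CartPole

agent = DQNAgent(state_dim=state_dim, action_dim=action_dim, lr=1e-3, gamma=0.99,
                 epsilon=1.0, eps_decay=1e-4, buffer_capacity=10000,
                 batch_size=64, target_update_freq=100)

for episode in range(200):
    state, _ = env.reset()
    episode_reward = 0.0
    done = False
    while not done:
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated   # truncation treated as terminal for simplicity
        agent.store(state, action, reward, next_state, done)
        agent.train()                    # one gradient step per environment step
        state = next_state
        episode_reward += reward
    print(f'episode {episode}, reward {episode_reward:.1f}, epsilon {agent.epsilon:.2f}')

Note that this loop calls agent.train() once per environment step, so target_update_freq and eps_decay are measured in environment steps.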
Explanation:
- DQN class: defines the Q-network architecture. It takes a state as input and outputs one Q-value per possible action.
- ReplayBuffer class: the experience replay buffer. It wraps a fixed-size deque that stores experience tuples (state, action, reward, next_state, done).
- DQNAgent class: the DQN agent. It holds two networks, policy_net and target_net, plus a replay buffer. policy_net predicts Q-values and is used to pick actions; target_net is used to compute the target Q-values. The agent also implements action selection, training, and experience storage.
- act method: the action-selection function. It takes a state as input and returns the chosen action.
- train method: trains the DQN. At each training step the agent samples a random batch of experience tuples from the replay buffer and regresses the policy network's Q-values toward the target Q-values (the corresponding target and loss are written out after this list).
- store method: stores an experience tuple in the replay buffer.
- target_update_freq: how often the target network is updated. Every target_update_freq training steps, the policy network's parameters are copied into the target network. This trick stabilizes training and improves performance.
- epsilon and eps_decay: the parameters of the ε-greedy policy. At each action selection, the agent picks a random action with probability ε and the current greedy (highest-Q) action with probability 1 − ε. As training proceeds, ε is decayed linearly toward the 0.1 floor hard-coded in train, so the agent increasingly exploits what it has learned. For example, with epsilon = 1.0 and eps_decay = 1e-4, ε reaches 0.1 after 9,000 training steps.
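For reference, the target and loss computed in the train method correspond to the standard DQN update, with θ the policy_net parameters, θ⁻ the target_net parameters, and d the done flag:

y = r + \gamma \, (1 - d) \, \max_{a'} Q_{\theta^-}(s', a'), \qquad L(\theta) = \big( Q_{\theta}(s, a) - y \big)^2

averaged over the sampled batch by nn.MSELoss. When d = 1 the bootstrap term vanishes, so terminal transitions are regressed toward the immediate reward alone.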
