Below is a PyTorch implementation of DQN:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), np.array(action), np.array(reward, dtype=np.float32), np.array(next_state), np.array(done, dtype=np.uint8)

    def __len__(self):
        return len(self.buffer)

class DQNAgent:
    def __init__(self, state_dim, action_dim, lr, gamma, epsilon, eps_decay, buffer_capacity, batch_size, target_update_freq):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.buffer = ReplayBuffer(buffer_capacity)
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.steps = 0

        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

    def act(self, state):
        # Epsilon-greedy action selection: exploit with probability 1 - epsilon,
        # otherwise explore with a uniformly random action.
        if random.random() > self.epsilon:
            with torch.no_grad():  # no gradients are needed for action selection
                state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
                q_value = self.policy_net(state)
                action = q_value.argmax(-1).item()
        else:
            action = random.choice(range(self.action_dim))
        return action

    def train(self):
        # Do not update until the buffer holds at least one full batch of transitions.
        if len(self.buffer) < self.batch_size:
            return

        state, action, reward, next_state, done = self.buffer.sample(self.batch_size)

        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        next_state = torch.tensor(next_state, dtype=torch.float32).to(self.device)
        action = torch.tensor(action, dtype=torch.int64).unsqueeze(-1).to(self.device)
        reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(-1).to(self.device)
        done = torch.tensor(done, dtype=torch.uint8).unsqueeze(-1).to(self.device)

        # Q(s, a) of the actions actually taken in the batch.
        q_values = self.policy_net(state).gather(-1, action)
        # Bootstrap value from the target network; detach() keeps gradients out of it.
        next_q_values = self.target_net(next_state).max(-1, keepdim=True)[0].detach()
        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed out for terminal transitions.
        expected_q_values = reward + self.gamma * next_q_values * (1 - done)

        loss = self.loss_fn(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.steps += 1
        # Linearly decay epsilon, but never below 0.1.
        self.epsilon = max(0.1, self.epsilon - self.eps_decay)

        # Periodically copy the policy network's weights into the target network.
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def store(self, state, action, reward, next_state, done):
        self.buffer.push(state, action, reward, next_state, done)

Explanation:

  • DQN class: defines the neural network used by DQN. The network takes a state as input and outputs one Q-value for each possible action (a quick shape check is sketched after this list).

  • ReplayBuffer class: defines the experience replay buffer. It maintains a double-ended queue (deque) that stores experience tuples (a short usage sketch follows after this list).

  • DQNAgent class: defines the DQN agent. It holds two neural networks, policy_net and target_net, plus an experience replay buffer. policy_net predicts Q-values and produces actions, while target_net is used to compute the target Q-values. The agent also implements action selection, training, and experience storage; a minimal end-to-end training loop is sketched after this list.

  • act method: the action-selection function. It takes a state as input and returns the chosen action.

  • train method: the function that trains the DQN. On each call, the agent samples a random batch of experience tuples from the replay buffer and regresses the policy network's Q-values toward the TD targets y = r + gamma * max_a' Q_target(s', a'), with the bootstrap term dropped for terminal transitions; this is how learning happens (a small worked example of the target computation is given after this list).

  • store method: stores an experience tuple in the experience replay buffer.

  • target_update_freq: the update frequency of the target network. Every target_update_freq training steps, the agent copies the policy network's parameters into the target network. This trick stabilizes training and improves performance (a common soft-update alternative is sketched after this list).

  • epsilon and eps_decay: the parameters of the ε-greedy policy. At each action selection, the agent picks a random action with probability ε and the currently best (greedy) action with probability 1 - ε. As training progresses, ε is gradually decreased, so the agent increasingly favors the greedy action (see the decay arithmetic sketched after this list).
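
A quick shape check for the DQN network described above (a minimal sketch; the 4-dimensional state and 2 actions are illustrative values, not taken from the original text):

net = DQN(state_dim=4, action_dim=2)
states = torch.randn(32, 4)      # a batch of 32 states
q_values = net(states)           # one Q-value per action for each state
print(q_values.shape)            # torch.Size([32, 2])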
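
Basic usage of the ReplayBuffer, again a sketch with made-up values:

buffer = ReplayBuffer(capacity=1000)
# push a single (state, action, reward, next_state, done) transition
buffer.push(np.zeros(4), 0, 1.0, np.ones(4), False)
if len(buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)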
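
To show how the DQNAgent is driven from the outside, here is a minimal end-to-end training loop. It assumes the Gymnasium API and the CartPole-v1 environment, and the hyperparameter values are illustrative; none of this appears in the original code:

import gymnasium as gym

env = gym.make('CartPole-v1')
agent = DQNAgent(state_dim=4, action_dim=2, lr=1e-3, gamma=0.99,
                 epsilon=1.0, eps_decay=1e-4, buffer_capacity=10000,
                 batch_size=64, target_update_freq=100)

for episode in range(500):
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.act(state)                              # epsilon-greedy action
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.store(state, action, reward, next_state, done)   # save the transition
        agent.train()                                          # one gradient step per env step
        state = next_state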
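
A worked example of the target computation in train, using made-up numbers: for a non-terminal transition with reward 1.0, gamma = 0.99, and a maximum target-network Q-value of 2.5 for the next state,

reward, gamma, next_q_max, done = 1.0, 0.99, 2.5, 0
target = reward + gamma * next_q_max * (1 - done)   # 1.0 + 0.99 * 2.5 = 3.475

and the MSE loss pulls the policy network's Q(s, a) toward 3.475. For a terminal transition (done = 1) the target is just the reward.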
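
A common alternative to the periodic hard copy of the target network is a soft (Polyak) update that blends the two networks a little on every step. This is not used in the code above; it is only a sketch of the variant, assuming an agent instance like the one in the training-loop sketch and an illustrative mixing coefficient tau:

tau = 0.005
for t_param, p_param in zip(agent.target_net.parameters(), agent.policy_net.parameters()):
    # move the target network slightly toward the policy network
    t_param.data.copy_(tau * p_param.data + (1 - tau) * t_param.data)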
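
Since train decreases epsilon linearly with a floor of 0.1, the number of training steps until exploration stops shrinking is easy to compute (the starting values below are illustrative):

epsilon, eps_decay, eps_min = 1.0, 1e-4, 0.1
steps_to_floor = (epsilon - eps_min) / eps_decay
print(steps_to_floor)   # 9000.0 training steps until epsilon reaches the 0.1 floor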
