以下是在 CartPole-v1 环境中训练 DQN 智能体的 PyTorch 实现代码：

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque


class DQNAgent(nn.Module):
    """Deep Q-Network agent: an MLP Q-function with replay memory and an
    epsilon-greedy policy.

    The network maps an observation vector to one Q-value per discrete action.
    The same network provides both the online Q estimate and the bootstrap
    target (there is no separate target network).

    Args:
        env: Environment exposing ``observation_space.shape[0]``,
            ``action_space.n`` and ``action_space.sample()``.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Initial exploration probability used by ``act``.
        epsilon_min: Epsilon stops decaying once it is no longer above this
            value (it can undershoot it by at most one decay step).
        epsilon_decay: Multiplicative factor applied to epsilon after every
            training step performed by ``replay``.
        lr: Adam learning rate.
        batch_size: Transitions sampled per training step; ``replay`` does
            nothing until the memory holds at least this many.
        memory_size: Capacity of the replay memory; once full, the oldest
            transitions are evicted first (``deque(maxlen=...)``).
    """

    def __init__(self, env, gamma=0.99, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995, lr=0.001, batch_size=64, memory_size=10000):
        super(DQNAgent, self).__init__()
        self.env = env
        self.observation_space = env.observation_space.shape[0]
        self.action_space = env.action_space.n
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.lr = lr
        self.batch_size = batch_size
        # Each entry is a (state, action, reward, next_state, done) tuple.
        self.memory = deque(maxlen=memory_size)

        # Two hidden layers of 24 units, one output per action.
        self.fc1 = nn.Linear(self.observation_space, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, self.action_space)

        self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
        self.loss_function = nn.MSELoss()

    def forward(self, x):
        """Return Q-values for observation(s) ``x``; the last dim of the
        output has one entry per action."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory.

        ``done`` should be true only when the episode truly terminated: it
        zeroes the bootstrapped next-state value in ``replay``.
        """
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action epsilon-greedily for a single observation.

        With probability ``self.epsilon`` a random action is sampled from the
        environment's action space. Otherwise the index of the largest
        predicted Q-value for ``state`` is returned.
        """
        if random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample()

        # Pure inference: building an autograd graph here is wasted work.
        with torch.no_grad():
            q_value = self.forward(torch.as_tensor(state, dtype=torch.float32))
        return torch.argmax(q_value).item()

    def replay(self):
        """Run one gradient step on a random minibatch from replay memory.

        Does nothing until the memory holds at least ``batch_size``
        transitions. The regression target for each sampled transition is
        ``reward + gamma * (1 - done) * max_a Q(next_state, a)``, computed
        with the current network and treated as a constant. After a training
        step, epsilon is multiplied by ``epsilon_decay`` while it is above
        ``epsilon_min``.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack per-sample tensors: handing a sequence of numpy arrays to the
        # tensor constructor is very slow and makes PyTorch emit a warning.
        states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in states])
        next_states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in next_states])
        actions = torch.tensor([int(a) for a in actions], dtype=torch.long).view(-1, 1)
        rewards = torch.tensor([float(r) for r in rewards], dtype=torch.float32).view(-1, 1)
        dones = torch.tensor([float(d) for d in dones], dtype=torch.float32).view(-1, 1)

        # Q(s, a) for the actions actually taken.
        q_values = self.forward(states).gather(1, actions)
        # The TD target is a constant w.r.t. the parameters being optimized.
        with torch.no_grad():
            next_q_values = self.forward(next_states).max(1)[0].view(-1, 1)
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = self.loss_function(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


if __name__ == '__main__':
    # Train a DQN agent on CartPole-v1 and print per-episode scores.
    env = gym.make('CartPole-v1')
    agent = DQNAgent(env)

    episodes = 1000
    try:
        for episode in range(episodes):
            # gym>=0.26 / gymnasium return (obs, info) from reset(); older gym
            # returns the observation alone.
            reset_result = env.reset()
            state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
            done = False
            score = 0

            while not done:
                action = agent.act(state)
                step_result = env.step(action)
                if len(step_result) == 5:
                    # New API: termination and time-limit truncation are separate.
                    next_state, reward, terminated, truncated, _ = step_result
                    done = terminated or truncated
                else:
                    # Old API: a TimeLimit cutoff is only flagged in the info dict.
                    next_state, reward, done, info = step_result
                    terminated = done and not info.get('TimeLimit.truncated', False)
                score += reward
                # Store only true termination as the "done" flag. A time-limit
                # cutoff is not a terminal state, so its next-state value must
                # still be bootstrapped in the TD target.
                agent.remember(state, action, reward, next_state, terminated)
                state = next_state
                agent.replay()

            print(f'Episode: {episode + 1}, Score: {score}, Epsilon: {agent.epsilon:.2f}')
    finally:
        env.close()

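运行时，算法在每个 episode 中与环境交互：每一步根据当前状态以 ε-贪心策略选择一个动作，并将（状态、动作、奖励、下一个状态、完成标志）存入记忆库。当记忆库中的样本数不少于 batch_size 时，每一步都会从记忆库中随机抽取一批样本。训练以 r + γ·(1 − done)·max_a′ Q(s′, a′) 作为目标，用均方误差损失更新 Q 网络的权重。同时，每完成一次训练，探索率 ε 都会乘以 epsilon_decay，直到不再高于 epsilon_min。因此，智能体会随着训练的进行逐渐减少探索、增加利用，以提高算法的性能。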


原文地址: https://www.cveoy.top/t/topic/lFG4 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录