Using SE-ResNet in DQN: A PyTorch Code Example
The following is a PyTorch code example of using an SE (Squeeze-and-Excitation) module, the core building block of SE-ResNet, in a DQN:
import math
import random
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class DQN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=128):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.se_layer = SELayer(hidden_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.se_layer(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
class SELayer(nn.Module):
    # SE gate adapted to flat feature vectors: with fully connected features there
    # are no spatial dimensions to pool over, so the "squeeze" step is the identity
    # and only the excitation MLP with sigmoid gating remains.
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x: (batch, channel); produce per-feature weights in (0, 1) and rescale x
        y = self.fc(x)
        return x * y
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        # Stack into tensors with the dtypes the learning step expects
        # (float states/rewards/dones, int64 actions for gather()).
        return (torch.tensor(np.array(state), dtype=torch.float32),
                torch.tensor(action, dtype=torch.int64),
                torch.tensor(reward, dtype=torch.float32),
                torch.tensor(np.array(next_state), dtype=torch.float32),
                torch.tensor(done, dtype=torch.float32))

    def __len__(self):
        return len(self.buffer)
class Agent:
    def __init__(self, env, buffer_capacity=10000, batch_size=64, gamma=0.99,
                 eps_start=1.0, eps_end=0.01, eps_decay=200):
        self.env = env
        self.buffer = ReplayBuffer(buffer_capacity)
        self.batch_size = batch_size
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay  # decay time constant (in steps) for epsilon-greedy exploration
        self.steps_done = 0
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.policy_net = DQN(env.observation_space.shape[0], env.action_space.n).to(self.device)
        self.target_net = DQN(env.observation_space.shape[0], env.action_space.n).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters())
        self.loss_fn = nn.MSELoss()
    def act(self, state, eps=None):
        if eps is None:
            # Exponentially anneal epsilon from eps_start towards eps_end
            eps = self.eps_end + (self.eps_start - self.eps_end) * \
                math.exp(-1.0 * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if random.random() < eps:
            return self.env.action_space.sample()
        state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state)
        return q_values.argmax().item()
    def learn(self):
        if len(self.buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)
        # Q(s, a) for the actions actually taken
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Bellman target: r + gamma * max_a' Q_target(s', a'), with the bootstrap term zeroed at terminal states
        next_q_values = self.target_net(next_states).max(1)[0]
        expected_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        loss = self.loss_fn(q_values, expected_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    def update_target(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())
    def train(self, num_episodes):
        for episode in range(num_episodes):
            # Note: this uses the classic gym (<0.26) API, where reset() returns only
            # the observation and step() returns a 4-tuple.
            state = self.env.reset()
            total_reward = 0
            done = False
            while not done:
                action = self.act(state)
                next_state, reward, done, _ = self.env.step(action)
                self.buffer.push(state, action, reward, next_state, done)
                total_reward += reward
                state = next_state
                self.learn()
            self.update_target()
            print(f'Episode {episode+1}, Reward: {total_reward}')
env = gym.make('CartPole-v0')
agent = Agent(env)
agent.train(200)
In this code, the SELayer class implements the SE module, which is inserted into the DQN network. In the original SE block, global average pooling first squeezes each feature map into a single value; two fully connected layers followed by a sigmoid then produce per-channel weights, which are multiplied with the original features to emphasize the important ones. Because the DQN here operates on flat feature vectors rather than convolutional feature maps, the pooling step reduces to the identity, and only the excitation MLP and sigmoid gating are needed.
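As a quick sanity check of the SELayer defined above (the tensor sizes here are illustrative assumptions), you can verify that the gate preserves the input shape and produces weights in (0, 1):

layer = SELayer(channel=128, reduction=16)
x = torch.randn(4, 128)                        # a batch of 4 hidden-feature vectors
gates = layer.fc(x)                            # per-feature weights from the excitation MLP
out = layer(x)
print(out.shape)                               # torch.Size([4, 128]), same shape as the input
print(gates.min().item(), gates.max().item())  # all gate values lie in (0, 1)
print(torch.allclose(out, x * gates))          # True: the output is the input scaled by the gates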
By adding the SE module to the DQN network, the model can perform better, since it is able to focus more on relevant features and suppress interference from irrelevant ones.
Note that this code is for reference only. You may need to adjust the hyperparameters and network architecture to suit your specific problem.
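For example, if your observations are images rather than low-dimensional state vectors, the full convolutional form of the SE block can be used together with residual connections, as in SE-ResNet. The following is a minimal sketch of such a block; the names SELayer2d and SEBasicBlock and the exact layer layout are illustrative assumptions rather than part of the code above:

class SELayer2d(nn.Module):
    # Standard SE gate for convolutional feature maps of shape (batch, channels, H, W).
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)          # squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)                  # (b, c) channel descriptors
        y = self.fc(y).view(b, c, 1, 1)                  # per-channel weights in (0, 1)
        return x * y                                     # excitation: rescale each channel


class SEBasicBlock(nn.Module):
    # Residual block with an SE gate, the basic building block of SE-ResNet.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SELayer2d(channels, reduction)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)                               # reweight channels before the skip connection
        return torch.relu(out + x)                       # residual connection

A stack of such blocks, followed by global pooling and a fully connected head like the one above, would give a convolutional SE-ResNet-style Q-network.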