将如下代码生成为正确可运行的代码 (Turn the following code into correct, runnable code): import torch import torch.nn as nn import time # Press Shift+F10 to execute it or replace it with your code. # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
import time

import torch
import torch.nn as nn
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools every channel to a scalar, feeds the resulting
    ``(b, c)`` vector through a two-layer bottleneck MLP that ends in a
    sigmoid, and rescales each input channel by its weight in ``(0, 1)``.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck ratio; the hidden layer has
            ``max(1, channel // reduction)`` units.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Without the clamp, channel < reduction yields a zero-width hidden
        # layer whose output is always 0, i.e. a constant sigmoid(0) = 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with each channel scaled by its attention weight.

        Args:
            x: 4-D tensor unpacked as ``(b, c, h, w)``; ``c`` must equal
                the ``channel`` given at construction.
        """
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class SEBasicBlock(nn.Module):
    """Residual basic block (two 3x3 convs) with an optional SE layer.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN [-> SE].
    Shortcut: the input itself, or ``downsample(input)`` when given.
    Output: ReLU(main + shortcut).

    Args:
        in_channel: Channels of the input tensor.
        out_channel: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the
            shortcut; its output must match the main path's shape.
        reduction: Bottleneck ratio for the SE layer (used only when
            ``use_se`` is true).
        use_se: Insert an ``SELayer`` after ``bn2``. Defaults to False,
            so the default block has no SE parameters.
    """

    # Channel multiplier of the block's output (out_channel * 1 channels).
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 reduction=8, use_se=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # Opt-in replacement for the previously commented-out SE layer. A None
        # attribute is not registered as a submodule, so the default block's
        # parameters and state_dict keys contain no SE entries.
        self.se = SELayer(out_channel, reduction) if use_se else None
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.se is not None:
            out = self.se(out)
        out += identity
        return self.relu(out)
class DQN(nn.Module):
    """Dueling DQN: residual convolutional trunk plus value/advantage heads.

    The trunk has three stages built by ``_make_layer`` (in_planes -> 256 ->
    128 -> 64 channels, two blocks each; the first block of every stage uses
    ``stride``). The trunk output is adaptively average-pooled to a fixed
    4x4 grid, so the dense layers always receive ``64 * 4 * 4`` features
    whatever the input resolution. Q-values are ``V + A - mean(A)``.

    Args:
        in_planes: Number of input channels.
        outputs: Number of actions (width of the advantage head and of Q).
        stride: Stride of the first block in each trunk stage.
        hidden_size: Width of the hidden dense layers (default 128).
    """

    def __init__(self, in_planes, outputs, stride=1, hidden_size=128):
        super().__init__()
        # Residual trunk (original design note: "residual + SE attention").
        self.conv1 = self._make_layer(SEBasicBlock, in_planes, 256, 2, stride=stride)
        self.conv2 = self._make_layer(SEBasicBlock, 256, 128, 2, stride=stride)
        self.conv3 = self._make_layer(SEBasicBlock, 128, 64, 2, stride=stride)
        # The original hard-coded nn.Linear(6444, 128), most likely a mangled
        # "64*4*4" that matched no real input. Pooling to 4x4 makes that size
        # hold for any input resolution (identity when the map is already 4x4).
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc1 = nn.Linear(64 * 4 * 4, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 256)
        self.val_hidden = nn.Linear(256, hidden_size)
        self.adv_hidden = nn.Linear(256, hidden_size)
        self.val = nn.Linear(hidden_size, 1)
        self.adv = nn.Linear(hidden_size, outputs)

    def _make_layer(self, block, in_channel, out_channel, num_blocks, stride):
        """Build one trunk stage of ``num_blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride or width
        changes, a 1x1-conv + BatchNorm projection shortcut. The remaining
        blocks map ``out_channel -> out_channel`` with the block's default
        stride.
        """
        downsample = None
        if stride != 1 or in_channel != out_channel:
            downsample = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )
        layers = [block(in_channel, out_channel, stride, downsample)]
        layers.extend(block(out_channel, out_channel) for _ in range(1, num_blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values of shape ``(batch, outputs)``.

        Args:
            x: Batch of shape ``(batch, in_planes, H, W)``.
        """
        out = self.conv3(self.conv2(self.conv1(x)))
        out = torch.flatten(self.pool(out), 1)
        # ReLU between the dense layers; without it the stacked Linear layers
        # collapse into a single affine map.
        out = torch.relu(self.fc1(out))
        out = torch.relu(self.fc2(out))
        adv = self.adv(torch.relu(self.adv_hidden(out)))
        val = self.val(torch.relu(self.val_hidden(out)))
        # Dueling aggregation: subtracting the mean advantage keeps V and A
        # identifiable; val (batch, 1) broadcasts over the action dimension.
        return val + adv - adv.mean(dim=1, keepdim=True)
def print_hi(name):
    """Write the greeting ``Hi, <name>`` followed by a newline to stdout."""
    greeting = 'Hi, {}'.format(name)
    print(greeting)
if __name__ == '__main__':
    # Smoke test: build a small network and time a single inference pass.
    print_hi('PyCharm')
    net = DQN(1, 3)
    net.eval()
    x = torch.randn(1, 1, 16, 16)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    with torch.no_grad():  # inference only; skip building the autograd graph
        y = net(x)
    end_time = time.perf_counter()
    print('Output size:', y.size())
    print('Time elapsed:', end_time - start_time)
原文地址: https://www.cveoy.top/t/topic/bzcS 著作权归作者所有。请勿转载和采集!