# A TrajGRU model implemented with PyTorch.
import torch
import torch.nn as nn
class TrajGRU(nn.Module):
    """Convolutional GRU cell operating on 2D feature maps.

    NOTE(review): despite the name, this cell contains no trajectory/flow
    warping (the defining feature of TrajGRU in Shi et al. 2017); it is a
    plain ConvGRU cell. The name is kept for interface compatibility.

    Args:
        input_dim: number of channels of the input tensor ``x``.
        hidden_dim: number of channels of the hidden state ``h``.
        num_layers: unused here; accepted for interface compatibility with
            callers that pass it (layer stacking is handled by TrajGRUNet).
        kernel_size: side length of the square convolution kernel.
            Padding is ``kernel_size // 2``, which preserves spatial size
            only for odd kernel sizes.
        dropout: probability for the Dropout2d applied to the new state.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, kernel_size, dropout):
        # Bug fix: the original defined ``init`` and called ``super().init()``,
        # so the constructor never ran and no layers were registered.
        super(TrajGRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.dropout = dropout

        padding = kernel_size // 2
        # Gate convolutions: z (update), r (reset), n (candidate state),
        # each fed by the input (conv_x*) and the hidden state (conv_h*).
        self.conv_xz = nn.Conv2d(input_dim, hidden_dim, kernel_size, padding=padding)
        self.conv_hz = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding)
        self.conv_xr = nn.Conv2d(input_dim, hidden_dim, kernel_size, padding=padding)
        self.conv_hr = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding)
        self.conv_xn = nn.Conv2d(input_dim, hidden_dim, kernel_size, padding=padding)
        self.conv_hn = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding)
        self.dropout_layer = nn.Dropout2d(p=dropout)

    def forward(self, x, h):
        """Run one recurrence step.

        Args:
            x: input of shape ``(batch, input_dim, H, W)``.
            h: previous hidden state of shape ``(batch, hidden_dim, H, W)``.

        Returns:
            Next hidden state, shape ``(batch, hidden_dim, H, W)``.
        """
        z = torch.sigmoid(self.conv_xz(x) + self.conv_hz(h))  # update gate
        r = torch.sigmoid(self.conv_xr(x) + self.conv_hr(h))  # reset gate
        n = torch.tanh(self.conv_xn(x) + self.conv_hn(r * h))  # candidate
        # Standard GRU blend: z close to 1 keeps the old state.
        h_next = (1 - z) * n + z * h
        return self.dropout_layer(h_next)
class TrajGRUNet(nn.Module):
    """Stack of TrajGRU cells unrolled over a sequence of 2D frames.

    Args:
        input_dim: channels of each input frame.
        hidden_dim: channels of every hidden state (same for all layers).
        num_layers: number of stacked TrajGRU cells.
        kernel_size: convolution kernel size forwarded to each cell.
        dropout: dropout probability forwarded to each cell.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, kernel_size, dropout):
        # Bug fix: the original defined ``init`` and called ``super().init()``.
        super(TrajGRUNet, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.dropout = dropout

        # First layer consumes the raw input channels; deeper layers consume
        # the previous layer's hidden channels.
        self.trajgru_layers = nn.ModuleList(
            TrajGRU(
                input_dim if i == 0 else hidden_dim,
                hidden_dim,
                num_layers,
                kernel_size,
                dropout,
            )
            for i in range(num_layers)
        )

    def forward(self, x):
        """Process a frame sequence through the stacked recurrent layers.

        Bug fixes vs. the original: the hidden state is now unrolled over
        time per layer (the original fed the flattened ``(b*t, ...)`` batch
        against a ``(b, ...)`` hidden state, a shape mismatch), the name
        ``h`` no longer shadows the image height, and the result is
        actually returned (the original ended with a bare ``return``).

        Args:
            x: tensor of shape ``(batch, time, channels, H, W)``.

        Returns:
            Hidden states of the last layer at every timestep,
            shape ``(batch, time, hidden_dim, H, W)``.
        """
        batch, steps, _, height, width = x.size()
        layer_input = x
        for layer in self.trajgru_layers:
            # Fresh zero state per layer, matching the input's device/dtype.
            state = x.new_zeros((batch, self.hidden_dim, height, width))
            outputs = []
            for t in range(steps):
                state = layer(layer_input[:, t], state)
                outputs.append(state)
            # The full output sequence feeds the next layer.
            layer_input = torch.stack(outputs, dim=1)
        return layer_input
# Source: https://www.cveoy.top/t/topic/fdwa — copyright remains with the original author; do not repost or scrape.