Suppose the search region tensor has shape (2, 1, 768, 18, 18) and the template region tensor has shape (2, 1, 768, 8, 8), where the first dimension is the batch size, the second is num_frames, and the third is the embedding dimension. How can the divided space-time attention from TimeSformer be used to exchange temporal information between the search region and the template region? Please give concrete code, especially for the temporal attention between images of different sizes.
First, the search region and template region need to be rearranged so that attention can operate on them: the spatial dimensions are flattened into a single token axis and the embedding dimension is moved last. In PyTorch this can be done with flatten and transpose.
The code is as follows:
# (B, F, C, H, W) -> (B, F, H*W, C): one token per spatial location
search_area = search_area.flatten(3).transpose(2, 3)
template_area = template_area.flatten(3).transpose(2, 3)
Here, flatten(3) merges the height and width dimensions into a single token axis of length H*W, and transpose(2, 3) moves the embedding dimension last, giving tensors of shape (batch_size, num_frames, num_tokens, embed_dim). Because attention treats the token axis as an ordinary sequence, the search region (18 * 18 = 324 tokens) and the template region (8 * 8 = 64 tokens) can pass through the same module even though their spatial sizes differ.
Next, TimeSformer's divided space-time attention can be used to exchange temporal information between the search region and the template region. Each block consists of a temporal attention, which lets tokens at the same spatial position attend to each other across frames, followed by a spatial attention, which lets tokens within the same frame attend to each other, plus an MLP.
The code is as follows:
import torch
import torch.nn as nn
from timm.models.layers import Mlp

class TimeSformerAttention(nn.Module):
    """Multi-head self-attention over the frame axis of a (B, F, N, C) tensor.

    Each of the N spatial tokens attends only to itself across the F frames,
    so the token count N never enters the attention pattern; this is what
    lets inputs of different spatial sizes share the same module.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x):
        B, F, N, C = x.shape
        # Project to q, k, v and split heads: each becomes (B, heads, N, F, head_dim).
        qkv = self.qkv(x).reshape(B, F, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(3, 0, 4, 2, 1, 5).unbind(0)
        # Frame-to-frame attention scores, computed independently per spatial token.
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, F, F)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge heads back: (B, heads, N, F, head_dim) -> (B, F, N, C).
        x = (attn @ v).permute(0, 3, 2, 1, 4).reshape(B, F, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
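As a quick sanity check, the module preserves the (B, F, N, C) layout; the frame and token counts below are illustrative values, not taken from the question:

attn = TimeSformerAttention(embed_dim=768, num_heads=12)
x = torch.randn(2, 4, 16, 768)   # (batch, frames, spatial tokens, embed dim)
assert attn(x).shape == x.shape  # temporal attention keeps the token layout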
class TimeSformer(nn.Module):
    """A stack of divided space-time blocks: temporal attention, then spatial
    attention, then an MLP, each with a pre-norm residual connection.
    """

    def __init__(self, embed_dim=768, num_heads=12, depth=12, mlp_ratio=4,
                 qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        self.depth = depth
        # Temporal attention layers.
        self.time_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.time_attns = nn.ModuleList([
            TimeSformerAttention(embed_dim, num_heads, qkv_bias, attn_dropout, proj_dropout)
            for _ in range(depth)])
        # Spatial attention layers (the same attention module, applied over the token axis).
        self.space_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.space_attns = nn.ModuleList([
            TimeSformerAttention(embed_dim, num_heads, qkv_bias, attn_dropout, proj_dropout)
            for _ in range(depth)])
        # Per-block MLPs.
        self.mlp_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.mlps = nn.ModuleList([
            Mlp(embed_dim, int(embed_dim * mlp_ratio), embed_dim) for _ in range(depth)])

    def forward(self, x):
        # x: (B, F, N, C)
        for i in range(self.depth):
            # Temporal attention: tokens at the same spatial position attend across frames.
            x = x + self.time_attns[i](self.time_norms[i](x))
            # Spatial attention: swapping the F and N axes makes the same attention
            # module attend across spatial tokens within each frame.
            y = self.space_norms[i](x).transpose(1, 2)
            x = x + self.space_attns[i](y).transpose(1, 2)
            x = x + self.mlps[i](self.mlp_norms[i](x))
        return x
Note that in the code above, TimeSformerAttention performs the interaction along the time dimension, while TimeSformer combines it with interaction along the spatial dimension, forming a complete divided space-time attention stack. Since neither attention depends on the number of spatial tokens, the same weights can process the search region and the template region despite their different sizes.
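A minimal check of that shared-weight property (the depth of 2 here is just to keep the example light):

tsf = TimeSformer(embed_dim=768, num_heads=12, depth=2)
out_search = tsf(torch.randn(2, 1, 18 * 18, 768))   # 324 search tokens
out_template = tsf(torch.randn(2, 1, 8 * 8, 768))   # 64 template tokens
print(out_search.shape, out_template.shape)         # same shapes as the inputs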
Finally, after the interaction, the search region and the template region can be transposed back and reshaped to their original (B, F, C, H, W) layouts.
The code is as follows:
# (B, F, N, C) -> (B, F, C, H, W)
search_area = search_area.transpose(2, 3).reshape(
    batch_size, num_frames, embed_dim, search_height, search_width)
template_area = template_area.transpose(2, 3).reshape(
    batch_size, num_frames, embed_dim, template_height, template_width)
The complete code is as follows:
import torch
import torch.nn as nn
from timm.models.layers import Mlp

class TimeSformerAttention(nn.Module):
    """Multi-head self-attention over the frame axis of a (B, F, N, C) tensor.

    Each of the N spatial tokens attends only to itself across the F frames,
    so the token count N never enters the attention pattern; this is what
    lets inputs of different spatial sizes share the same module.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x):
        B, F, N, C = x.shape
        # Project to q, k, v and split heads: each becomes (B, heads, N, F, head_dim).
        qkv = self.qkv(x).reshape(B, F, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(3, 0, 4, 2, 1, 5).unbind(0)
        # Frame-to-frame attention scores, computed independently per spatial token.
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, F, F)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge heads back: (B, heads, N, F, head_dim) -> (B, F, N, C).
        x = (attn @ v).permute(0, 3, 2, 1, 4).reshape(B, F, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class TimeSformer(nn.Module):
    """A stack of divided space-time blocks: temporal attention, then spatial
    attention, then an MLP, each with a pre-norm residual connection.
    """

    def __init__(self, embed_dim=768, num_heads=12, depth=12, mlp_ratio=4,
                 qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        self.depth = depth
        # Temporal attention layers.
        self.time_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.time_attns = nn.ModuleList([
            TimeSformerAttention(embed_dim, num_heads, qkv_bias, attn_dropout, proj_dropout)
            for _ in range(depth)])
        # Spatial attention layers (the same attention module, applied over the token axis).
        self.space_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.space_attns = nn.ModuleList([
            TimeSformerAttention(embed_dim, num_heads, qkv_bias, attn_dropout, proj_dropout)
            for _ in range(depth)])
        # Per-block MLPs.
        self.mlp_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.mlps = nn.ModuleList([
            Mlp(embed_dim, int(embed_dim * mlp_ratio), embed_dim) for _ in range(depth)])

    def forward(self, x):
        # x: (B, F, N, C)
        for i in range(self.depth):
            # Temporal attention: tokens at the same spatial position attend across frames.
            x = x + self.time_attns[i](self.time_norms[i](x))
            # Spatial attention: swapping the F and N axes makes the same attention
            # module attend across spatial tokens within each frame.
            y = self.space_norms[i](x).transpose(1, 2)
            x = x + self.space_attns[i](y).transpose(1, 2)
            x = x + self.mlps[i](self.mlp_norms[i](x))
        return x
# Search area and template area shapes.
batch_size = 2
num_frames = 1
embed_dim = 768
search_height, search_width = 18, 18
template_height, template_width = 8, 8

# Create random search area and template area tensors: (B, F, C, H, W).
search_area = torch.randn(batch_size, num_frames, embed_dim, search_height, search_width)
template_area = torch.randn(batch_size, num_frames, embed_dim, template_height, template_width)

# Flatten the spatial dimensions into a token axis: (B, F, C, H, W) -> (B, F, H*W, C).
search_area = search_area.flatten(3).transpose(2, 3)
template_area = template_area.flatten(3).transpose(2, 3)

# Apply divided space-time attention; the same module handles both token counts.
tsf = TimeSformer(embed_dim=embed_dim)
search_area = tsf(search_area)
template_area = tsf(template_area)

# Reshape and transpose back to (B, F, C, H, W).
search_area = search_area.transpose(2, 3).reshape(
    batch_size, num_frames, embed_dim, search_height, search_width)
template_area = template_area.transpose(2, 3).reshape(
    batch_size, num_frames, embed_dim, template_height, template_width)
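The pipeline above runs the two regions through the shared blocks independently, so temporal interaction happens within each region but the regions never attend to each other. If direct template-search interaction is also wanted, one possible extension (a sketch, not part of the pipeline above) is to concatenate the two token sequences along the token axis before the shared blocks, so that the spatial attention mixes template and search tokens within each frame:

# Hypothetical extension: joint attention over the concatenated token sequences.
search_tokens = search_area.flatten(3).transpose(2, 3)      # (2, 1, 324, 768)
template_tokens = template_area.flatten(3).transpose(2, 3)  # (2, 1, 64, 768)
joint = torch.cat([template_tokens, search_tokens], dim=2)  # (2, 1, 388, 768)
joint = tsf(joint)
# Split back: the first 64 tokens are the template, the remaining 324 the search region.
template_tokens, search_tokens = joint.split([64, 324], dim=2)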