First, the search area and template area tensors need to be rearranged so that the time dimension is exposed as its own axis, which makes it possible to apply attention along time. PyTorch's flatten and transpose operations can be used to rearrange the input tensors.

The code is as follows:

search_area = search_area.flatten(3).transpose(2, 3)      # (B, F, H_s * W_s, C)
template_area = template_area.flatten(3).transpose(2, 3)  # (B, F, H_t * W_t, C)

Here, flatten(3) merges the height and width dimensions into a single token dimension, and transpose(2, 3) moves the embedding dimension to the last axis, yielding tensors of shape (batch_size, num_frames, num_tokens, embed_dim) as expected by the attention layers below. Note that the search area and template area end up with different numbers of tokens, which is fine: time attention operates along the frame axis independently for each spatial token.
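As a quick sanity check of the layout (a minimal sketch; the tensor contents are random placeholders):

import torch

x = torch.randn(2, 4, 768, 18, 18)    # (B, F, C, H, W)
tokens = x.flatten(3).transpose(2, 3)
print(tokens.shape)                   # torch.Size([2, 4, 324, 768])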

Next, TimeSformer's divided space-time attention can be used to exchange temporal information within the search area and within the template area. A divided attention block consists of time attention and space attention: time attention lets each spatial token attend to the same token position across frames, while space attention lets tokens within the same frame attend to each other.
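Concretely, the factorization only changes which axis is folded into the batch before a standard attention call. A shapes-only sketch (B, F, N, C stand for batch, frames, tokens, and embedding, as above):

# time attention: fold spatial tokens into the batch, attend over frames
t = tokens.transpose(1, 2).reshape(B * N, F, C)
# space attention: fold frames into the batch, attend over tokens
s = tokens.reshape(B * F, N, C)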

The code is as follows:

import torch
import torch.nn as nn

from timm.models.layers import Mlp

class TimeSformerAttention(nn.Module):
    """Time attention: each spatial token attends across the frame axis."""

    def __init__(self, embed_dim, num_heads, qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x):
        # x: (B, F, N, C) = (batch, frames, spatial tokens, embedding)
        B, F, N, C = x.shape
        qkv = self.qkv(x).reshape(B, F, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(3, 0, 4, 2, 1, 5)  # (3, B, heads, N, F, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # attention over the frame dimension, independently per spatial token
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, F, F)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 3, 2, 1, 4).reshape(B, F, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
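# A quick check (sketch with made-up shapes): the same time-attention
# module handles token grids of different sizes, because the spatial
# dimension N is never mixed into the attention computation:
#
#   attn = TimeSformerAttention(embed_dim=768, num_heads=12)
#   print(attn(torch.randn(2, 4, 18 * 18, 768)).shape)  # (2, 4, 324, 768)
#   print(attn(torch.randn(2, 4, 8 * 8, 768)).shape)    # (2, 4, 64, 768)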

class TimeSformer(nn.Module):
    """A stack of divided space-time attention blocks (simplified sketch).

    Each block applies time attention (across frames), then space
    attention (across spatial tokens), then an MLP, each with a
    pre-norm residual connection.
    """

    def __init__(self, embed_dim=768, num_heads=12, depth=12, mlp_ratio=4,
                 qkv_bias=False, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        # time attention layers (attend over the frame axis)
        self.time_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.time_attns = nn.ModuleList([
            TimeSformerAttention(embed_dim, num_heads, qkv_bias, attn_dropout, proj_dropout)
            for _ in range(depth)])

        # space attention layers: the same attention module applied with
        # the frame and token axes swapped, so it attends over tokens
        self.space_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.space_attns = nn.ModuleList([
            TimeSformerAttention(embed_dim, num_heads, qkv_bias, attn_dropout, proj_dropout)
            for _ in range(depth)])

        # feed-forward layers
        self.mlp_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(depth)])
        self.mlps = nn.ModuleList([
            Mlp(embed_dim, int(embed_dim * mlp_ratio), embed_dim)
            for _ in range(depth)])

    def forward(self, x):
        # x: (B, F, N, C)
        for i in range(self.depth):
            # time attention across frames
            x = x + self.time_attns[i](self.time_norms[i](x))
            # space attention across tokens: swap F and N, attend, swap back
            y = self.space_attns[i](self.space_norms[i](x).transpose(1, 2))
            x = x + y.transpose(1, 2)
            # feed-forward
            x = x + self.mlps[i](self.mlp_norms[i](x))
        return x

Note that in the code above, TimeSformerAttention attends over a single axis at a time, while TimeSformer chains time attention, space attention, and an MLP into complete divided space-time attention blocks. Because time attention treats every spatial token independently, the same TimeSformer instance can be applied to both the search area and the template area even though they contain different numbers of tokens.

Finally, after the interaction, the search area and template area can be transposed and reshaped back to their original layout.

The code is as follows:

search_area = search_area.transpose(2, 3).reshape(batch_size, num_frames, -1, search_height, search_width)
template_area = template_area.transpose(2, 3).reshape(batch_size, num_frames, -1, template_height, template_width)

A complete usage example is as follows (reusing the TimeSformerAttention and TimeSformer classes defined above):


# search area and template area shapes (example values)
batch_size = 2
num_frames = 2  # time attention is only meaningful with more than one frame
embed_dim = 768
search_height, search_width = 18, 18
template_height, template_width = 8, 8

# create random search area and template area tensors: (B, F, C, H, W)
search_area = torch.randn(batch_size, num_frames, embed_dim, search_height, search_width)
template_area = torch.randn(batch_size, num_frames, embed_dim, template_height, template_width)

# flatten the spatial grid into tokens: (B, F, N, C)
search_area = search_area.flatten(3).transpose(2, 3)
template_area = template_area.flatten(3).transpose(2, 3)

# apply divided space-time attention (a shallow stack for a quick test)
tsf = TimeSformer(embed_dim=embed_dim, num_heads=12, depth=2)
search_area = tsf(search_area)
template_area = tsf(template_area)

# reshape and transpose back to (B, F, C, H, W)
search_area = search_area.transpose(2, 3).reshape(batch_size, num_frames, embed_dim, search_height, search_width)
template_area = template_area.transpose(2, 3).reshape(batch_size, num_frames, embed_dim, template_height, template_width)
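
As a quick check, both tensors come back in their original layout (shapes assume the example values above):

print(search_area.shape)    # torch.Size([2, 2, 768, 18, 18])
print(template_area.shape)  # torch.Size([2, 2, 768, 8, 8])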
