import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from easydict import EasyDict
from typing import Tuple


class HMCAN(nn.Module):
    """Hierarchical Multimodal Contextual Attention Network."""

    def __init__(self, configs, alpha):
        super(HMCAN, self).__init__()
        self.word_length = configs.max_word_length
        self.alpha = alpha

        # One transformer for text-guided attention over image features,
        # and a second one for image-guided attention over text features.
        self.contextual_transform = TextImage_Transformer(
            configs.contextual_transform, configs.contextual_transform.output_dim)

        self.contextual_transform2 = TextImage_Transformer(
            configs.contextual_transform, configs.contextual_transform.output_dim)

        self.conv = nn.Conv2d(2048, 768, 1)
        self.bn = nn.BatchNorm2d(768)

        # Three fused features, each 2 * 768-dim (pooled + context), are concatenated.
        self.classifier = nn.Sequential(nn.Linear(768 * 6, 256),
                                        nn.ReLU(True),
                                        nn.BatchNorm1d(256),
                                        nn.Linear(256, 2))

    def forward(self, e, f):
        batch_size = len(e)

        e_f_mask = torch.ones(batch_size, self.word_length, device=e.device)
        f_e_mask = torch.ones(batch_size, 16, device=e.device)

        # Split the token embeddings into three equal-length segments.
        e = torch.squeeze(e, dim=1)  # [batch_size, 40, 768]
        e1 = e[:, :self.word_length, :]
        e2 = e[:, self.word_length:self.word_length * 2, :]
        e3 = e[:, self.word_length * 2:, :]

        # Project the image feature map to the text embedding dimension.
        f = F.relu(self.bn(self.conv(f)))        # [batch_size, 768, 4, 4]
        f = f.view(f.shape[0], f.shape[1], -1)   # [batch_size, 768, 16]
        f = f.permute([0, 2, 1])                 # [batch_size, 16, 768]

        # Bidirectional contextual attention, fused with weight alpha.
        c1_e1_f = self.contextual_transform(e1, e_f_mask, f)
        c1_f_e1 = self.contextual_transform2(f, f_e_mask, e1)
        a = self.alpha

        c1 = a * c1_e1_f + (1 - a) * c1_f_e1

        c2_e2_f = self.contextual_transform(e2, e_f_mask, f)
        c2_f_e2 = self.contextual_transform2(f, f_e_mask, e2)

        c2 = a * c2_e2_f + (1 - a) * c2_f_e2

        c3_e3_f = self.contextual_transform(e3, e_f_mask, f)
        c3_f_e3 = self.contextual_transform2(f, f_e_mask, e3)

        c3 = a * c3_e3_f + (1 - a) * c3_f_e3

        x = torch.cat((c1, c2, c3), dim=1)
        x = self.classifier(x)

        return x


class LayerNormalization(nn.Module):
    def __init__(self, features_count, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(
            torch.ones(features_count), requires_grad=True)
        self.bias = nn.Parameter(
            torch.zeros(features_count), requires_grad=True)
        self.epsilon = epsilon

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gain * (x - mean) / (std + self.epsilon) + self.bias


class TextImage_Transformer(nn.Module):
    def __init__(self, ct: EasyDict, feature_dim: int):
        super().__init__()

        self.input_norm = LayerNormalization(feature_dim)
        input_dim = feature_dim
        self.embedding = PositionalEncoding(
            input_dim, ct.dropout, max_len=1000)

        self.tf = TransformerEncoder(
            ct.num_layers, input_dim, ct.num_heads, input_dim,
            ct.dropout)

        self.use_context = ct.use_context
        if self.use_context:
            self.tf_context = TransformerEncoder(
                ct.atn_ct_num_layers, input_dim, ct.atn_ct_num_heads,
                input_dim, ct.dropout)

        init_network(self, 0.01)

    def forward(self, features, mask, hidden_state):
        # Self-attention over the input features, then (optionally)
        # cross-attention with the other modality's hidden state as query.
        features = self.input_norm(features)
        features = self.embedding(features)
        features = self.tf(features, features, features, mask)
        add_after_pool = None
        if self.use_context:
            ctx = self.tf_context(
                hidden_state, features, features, mask)
            add_after_pool = torch.mean(ctx, dim=1)

        pooled = torch.mean(features, dim=1)
        if add_after_pool is not None:
            pooled = torch.cat([pooled, add_after_pool], dim=-1)
        return pooled


class PositionalEncoding(nn.Module):
    def __init__(self, dim, dropout_prob=0., max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, dim).float()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        dimension = torch.arange(0, dim).float()
        div_term = 10000 ** (2 * dimension / dim)
        pe[:, 0::2] = torch.sin(position / div_term[0::2])
        pe[:, 1::2] = torch.cos(position / div_term[1::2])
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.dim = dim

    def forward(self, x, step=None):
        if step is None:
            x = x + self.pe[:x.size(1), :]
        else:
            x = x + self.pe[step, :]
        x = self.dropout(x)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob):
        super().__init__()
        self.d_model = d_model
        assert layers_count > 0
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(
                d_model, heads_count, d_ff, dropout_prob)
             for _ in range(layers_count)])

    def forward(self, query, key, value, mask):
        batch_size, query_len, embed_dim = query.shape
        batch_size, key_len, embed_dim = key.shape
        # mask: [batch_size, key_len] with 1 for valid positions; convert to a
        # boolean mask that is True where attention should be blocked.
        mask = (1 - mask.unsqueeze(1).expand(batch_size, query_len, key_len))
        mask = mask == 1
        sources = None
        for encoder_layer in self.encoder_layers:
            sources = encoder_layer(query, key, value, mask)
        return sources


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attention_layer = Sublayer(
            MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
        self.pointwise_feedforward_layer = Sublayer(
            PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, query, key, value, sources_mask):
        sources = self.self_attention_layer(query, key, value, sources_mask)
        sources = self.dropout(sources)
        sources = self.pointwise_feedforward_layer(sources)
        return sources


class Sublayer(nn.Module):
    """Residual connection around a sublayer, followed by layer normalization."""

    def __init__(self, sublayer, d_model):
        super(Sublayer, self).__init__()
        self.sublayer = sublayer
        self.layer_normalization = LayerNormalization(d_model)

    def forward(self, *args):
        x = args[0]
        x = self.sublayer(*args) + x
        return self.layer_normalization(x)


class MultiHeadAttention(nn.Module):
    def __init__(self, heads_count, d_model, dropout_prob):
        super().__init__()
        assert d_model % heads_count == 0, \
            f"model dim {d_model} not divisible by {heads_count} heads"
        self.d_head = d_model // heads_count
        self.heads_count = heads_count
        self.query_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.key_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.value_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.final_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.dropout = nn.Dropout(dropout_prob)
        self.softmax = nn.Softmax(dim=3)
        self.attention = None

    def forward(self, query, key, value, mask=None):
        batch_size, query_len, d_model = query.size()
        d_head = d_model // self.heads_count
        query_projected = self.query_projection(query)
        key_projected = self.key_projection(key)
        value_projected = self.value_projection(value)
        batch_size, key_len, d_model = key_projected.size()
        batch_size, value_len, d_model = value_projected.size()
        # Reshape to [batch_size, heads_count, seq_len, d_head].
        query_heads = query_projected.view(
            batch_size, query_len, self.heads_count, d_head).transpose(1, 2)
        key_heads = key_projected.view(
            batch_size, key_len, self.heads_count, d_head).transpose(1, 2)
        value_heads = value_projected.view(
            batch_size, value_len, self.heads_count, d_head).transpose(1, 2)
        attention_weights = self.scaled_dot_product(
            query_heads, key_heads)
        if mask is not None:
            mask_expanded = mask.unsqueeze(1).expand_as(attention_weights)
            attention_weights = attention_weights.masked_fill(
                mask_expanded, -1e18)
        attention = self.softmax(attention_weights)
        attention_dropped = self.dropout(attention)
        context_heads = torch.matmul(
            attention_dropped, value_heads)
        context_sequence = context_heads.transpose(1, 2)
        context = context_sequence.reshape(
            batch_size, query_len, d_model)
        final_output = self.final_projection(context)
        return final_output

    def scaled_dot_product(self, query_heads, key_heads):
        key_heads_transposed = key_heads.transpose(2, 3)
        dot_product = torch.matmul(
            query_heads, key_heads_transposed)
        attention_weights = dot_product / np.sqrt(self.d_head)
        return attention_weights


class PointwiseFeedForwardNetwork(nn.Module):
    def __init__(self, d_ff, d_model, dropout_prob):
        super(PointwiseFeedForwardNetwork, self).__init__()
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout_prob),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_prob))

    def forward(self, x):
        return self.feed_forward(x)


def truncated_normal_fill(
        shape: Tuple[int, ...], mean: float = 0, std: float = 1,
        limit: float = 2) -> torch.Tensor:
    # Draw several candidates per element and keep one inside [-limit, limit].
    num_examples = 8
    tmp = torch.empty(shape + (num_examples,)).normal_()
    valid = (tmp < limit) & (tmp > -limit)
    _, ind = valid.max(-1, keepdim=True)
    return tmp.gather(-1, ind).squeeze(-1).mul(std).add_(mean)


def init_weight_(w, init_gain=1):
    w.copy_(truncated_normal_fill(w.shape, std=init_gain))


def init_network(net: nn.Module, init_std: float):
    for key, val in net.named_parameters():
        if "weight" in key or "bias" in key:
            init_weight_(val.data, init_std)
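# The block below is a minimal smoke-test sketch, not part of the original
# listing. The config fields mirror the attributes read by the classes above
# (max_word_length, contextual_transform.output_dim/dropout/num_layers/
# num_heads/use_context/atn_ct_num_layers/atn_ct_num_heads); the concrete
# values and input shapes are illustrative assumptions, not the authors'
# settings.
if __name__ == "__main__":
    config = EasyDict(
        max_word_length=13,
        contextual_transform=EasyDict(
            output_dim=768, dropout=0.1, num_layers=1, num_heads=4,
            use_context=True, atn_ct_num_layers=1, atn_ct_num_heads=4))
    model = HMCAN(config, alpha=0.5)
    model.eval()  # use running BN stats, disable dropout

    # Assumed inputs: BERT-style token features and a 4x4 ResNet feature map.
    text = torch.randn(2, 1, 3 * config.max_word_length, 768)
    image = torch.randn(2, 2048, 4, 4)
    with torch.no_grad():
        logits = model(text, image)
    print(logits.shape)  # torch.Size([2, 2])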

HMCAN: Hierarchical Multimodal Contextual Attention Network for Image-Text Matching

