import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
import models.configs as configs
import math


class Attention(nn.Module):
    """Self-attention block with an optional multimodal (image-text) path."""

    def __init__(self, config, vis, mm=True):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Image (or unimodal) query/key/value projections.
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        if mm:
            # Separate projections for the text stream, plus dropouts for the
            # text-text, image-to-text (it), and text-to-image (ti) attention maps.
            self.query_text = Linear(config.hidden_size, self.all_head_size)
            self.key_text = Linear(config.hidden_size, self.all_head_size)
            self.value_text = Linear(config.hidden_size, self.all_head_size)
            self.out_text = Linear(config.hidden_size, config.hidden_size)
            self.attn_dropout_text = Dropout(config.transformer["attention_dropout_rate"])
            self.attn_dropout_it = Dropout(config.transformer["attention_dropout_rate"])
            self.attn_dropout_ti = Dropout(config.transformer["attention_dropout_rate"])
            self.proj_dropout_text = Dropout(config.transformer["attention_dropout_rate"])

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # (batch, tokens, hidden) -> (batch, heads, tokens, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward(self, hidden_states, text=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        if text is not None:
            text_q = self.query_text(text)
            text_k = self.key_text(text)
            text_v = self.value_text(text)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        if text is not None:
            query_layer_img = query_layer
            key_layer_img = key_layer
            value_layer_img = value_layer
            query_layer_text = self.transpose_for_scores(text_q)
            key_layer_text = self.transpose_for_scores(text_k)
            value_layer_text = self.transpose_for_scores(text_v)

        if text is None:
            # Standard scaled dot-product self-attention over the image tokens.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            attention_probs = self.softmax(attention_scores)
            weights = attention_probs if self.vis else None
            attention_probs = self.attn_dropout(attention_probs)

            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
            attention_output = self.out(context_layer)
            attention_output = self.proj_dropout(attention_output)
            return attention_output, None, weights
        else:
            # Four attention maps: image-image, text-text, image-to-text (it),
            # and text-to-image (ti).
            attention_scores_img = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores_text = torch.matmul(query_layer_text, key_layer_text.transpose(-1, -2))
            attention_scores_it = torch.matmul(query_layer_img, key_layer_text.transpose(-1, -2))
            attention_scores_ti = torch.matmul(query_layer_text, key_layer_img.transpose(-1, -2))

            attention_scores_img = attention_scores_img / math.sqrt(self.attention_head_size)
            attention_probs_img = self.softmax(attention_scores_img)
            weights_img = attention_probs_img if self.vis else None
            attention_probs_img = self.attn_dropout(attention_probs_img)

            attention_scores_text = attention_scores_text / math.sqrt(self.attention_head_size)
            attention_probs_text = self.softmax(attention_scores_text)
            # Per-token text weights (mean over keys, renormalized); computed but
            # not part of the return value below.
            text_per_weights = attention_probs_text.mean(dim=-1)
            text_per_weights = self.softmax(text_per_weights)
            weights_text = attention_probs_text if self.vis else None
            attention_probs_text = self.attn_dropout_text(attention_probs_text)

            attention_scores_it = attention_scores_it / math.sqrt(self.attention_head_size)
            attention_probs_it = self.softmax(attention_scores_it)
            attention_probs_it = self.attn_dropout_it(attention_probs_it)

            attention_scores_ti = attention_scores_ti / math.sqrt(self.attention_head_size)
            attention_probs_ti = self.softmax(attention_scores_ti)
            attention_probs_ti = self.attn_dropout_ti(attention_probs_ti)

            # Contextualized features from each of the four attention maps.
            context_layer_img = torch.matmul(attention_probs_img, value_layer_img)
            context_layer_img = context_layer_img.permute(0, 2, 1, 3).contiguous()
            context_layer_text = torch.matmul(attention_probs_text, value_layer_text)
            context_layer_text = context_layer_text.permute(0, 2, 1, 3).contiguous()
            context_layer_it = torch.matmul(attention_probs_it, value_layer_text)
            context_layer_it = context_layer_it.permute(0, 2, 1, 3).contiguous()
            context_layer_ti = torch.matmul(attention_probs_ti, value_layer_img)
            context_layer_ti = context_layer_ti.permute(0, 2, 1, 3).contiguous()

            # Merge the attention heads back into the hidden dimension.
            new_context_layer_shape = context_layer_img.size()[:-2] + (self.all_head_size,)
            context_layer_img = context_layer_img.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_text.size()[:-2] + (self.all_head_size,)
            context_layer_text = context_layer_text.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_it.size()[:-2] + (self.all_head_size,)
            context_layer_it = context_layer_it.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_ti.size()[:-2] + (self.all_head_size,)
            context_layer_ti = context_layer_ti.view(*new_context_layer_shape)

            # Each stream averages its self-attention and cross-attention features
            # before the shared output projection.
            attention_output_img = self.out((context_layer_img + context_layer_it) / 2)
            attention_output_text = self.out((context_layer_text + context_layer_ti) / 2)
            attention_output_img = self.proj_dropout(attention_output_img)
            attention_output_text = self.proj_dropout_text(attention_output_text)
            return attention_output_img, attention_output_text
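

# Minimal usage sketch (illustrative addition, not part of the original module).
# It builds a lightweight stand-in for the config object from models.configs,
# exposing only the fields Attention actually reads: `hidden_size` and the
# `transformer` dict with "num_heads" and "attention_dropout_rate". The batch
# size and token counts below are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=768,
        transformer={"num_heads": 12, "attention_dropout_rate": 0.0},
    )
    attn = Attention(cfg, vis=False, mm=True)

    img_tokens = torch.randn(2, 197, cfg.hidden_size)  # (batch, image tokens, hidden)
    txt_tokens = torch.randn(2, 40, cfg.hidden_size)   # (batch, text tokens, hidden)

    out_img, out_text = attn(img_tokens, txt_tokens)   # multimodal path
    out_uni, _, weights = attn(img_tokens)             # image-only path
    print(out_img.shape, out_text.shape, out_uni.shape)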