import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward layer implemented with 1x1 convolutions."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        # Conv1d expects (N, C, T), so move the feature dimension before the time dimension.
        output = x.transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output


class MultiHeadAttention(nn.Module):
    """Multi-head attention with key/query padding masks and causal (future-blinding) masking.

    Note: the head split below assumes num_units == hidden_size.
    """

    def __init__(self, hidden_size, num_units, num_heads, dropout_rate):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0

        self.linear_q = nn.Linear(hidden_size, num_units)
        self.linear_k = nn.Linear(hidden_size, num_units)
        self.linear_v = nn.Linear(hidden_size, num_units)
        self.dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys):
        """
        :param queries: A 3d tensor with shape of [N, T_q, C_q]
        :param keys: A 3d tensor with shape of [N, T_k, C_k]
        :return: A 3d tensor with shape of [N, T_q, C]
        """
        Q = self.linear_q(queries)  # (N, T_q, C)
        K = self.linear_k(keys)     # (N, T_k, C)
        V = self.linear_v(keys)     # (N, T_k, C)

        # Split into heads along the feature dimension and stack the heads along the batch dimension
        split_size = self.hidden_size // self.num_heads
        Q_ = torch.cat(torch.split(Q, split_size, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Scaled dot-product scores
        matmul_output = torch.bmm(Q_, K_.transpose(1, 2)) / self.hidden_size ** 0.5  # (h*N, T_q, T_k)

        # Key Masking: padded key positions (all-zero vectors) receive a large negative score
        key_mask = torch.sign(torch.abs(keys.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_k)
        key_mask_reshaped = key_mask.unsqueeze(1).repeat(1, queries.shape[1], 1)  # (h*N, T_q, T_k)
        key_paddings = torch.ones_like(matmul_output) * (-2 ** 32 + 1)
        matmul_output_m1 = torch.where(torch.eq(key_mask_reshaped, 0), key_paddings, matmul_output)  # (h*N, T_q, T_k)

        # Causality - Future Blinding: each position may attend only to itself and earlier positions
        diag_vals = torch.ones_like(matmul_output[0, :, :])  # (T_q, T_k)
        tril = torch.tril(diag_vals)  # (T_q, T_k)
        causality_mask = tril.unsqueeze(0).repeat(matmul_output.shape[0], 1, 1)  # (h*N, T_q, T_k)
        causality_paddings = torch.ones_like(causality_mask) * (-2 ** 32 + 1)
        matmul_output_m2 = torch.where(torch.eq(causality_mask, 0), causality_paddings, matmul_output_m1)  # (h*N, T_q, T_k)

        # Activation
        matmul_output_sm = self.softmax(matmul_output_m2)  # (h*N, T_q, T_k)

        # Query Masking: zero out attention rows that belong to padded query positions
        query_mask = torch.sign(torch.abs(queries.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_q)
        query_mask = query_mask.unsqueeze(-1).repeat(1, 1, keys.shape[1])  # (h*N, T_q, T_k)
        matmul_output_qm = matmul_output_sm * query_mask

        # Dropout
        matmul_output_dropout = self.dropout(matmul_output_qm)

        # Weighted Sum over values
        output_ws = torch.bmm(matmul_output_dropout, V_)  # (h*N, T_q, C/h)

        # Restore Shape: undo the head split
        output = torch.cat(torch.split(output_ws, output_ws.shape[0] // self.num_heads, dim=0), dim=2)  # (N, T_q, C)

        # Residual Connection
        output_res = output + queries

        return output_res
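

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original modules): it shows one way to
# chain MultiHeadAttention and PositionwiseFeedForward as a self-attention
# block. It assumes hidden_size == num_units (required by the head-split
# logic above) and that padded positions are all-zero vectors, which is what
# the key/query masks detect. All shapes and hyperparameters are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, seq_len, hidden_size = 2, 5, 16
    num_heads, dropout_rate = 4, 0.1

    attn = MultiHeadAttention(hidden_size, num_units=hidden_size,
                              num_heads=num_heads, dropout_rate=dropout_rate)
    ffn = PositionwiseFeedForward(d_in=hidden_size, d_hid=4 * hidden_size,
                                  dropout=dropout_rate)

    x = torch.randn(batch_size, seq_len, hidden_size)
    x[:, -1, :] = 0.0  # treat the last position as padding (all zeros)

    attn.eval()
    ffn.eval()
    with torch.no_grad():
        out = ffn(attn(x, x))  # self-attention followed by the feed-forward layer

    print(out.shape)  # torch.Size([2, 5, 16])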