import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward layer implemented with two 1x1 convolutions."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # expand: d_in -> d_hid
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # project back: d_hid -> d_in
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)
|
    def forward(self, x):
        residual = x
        # nn.Conv1d expects (N, C, T), so swap the sequence and channel dimensions.
        output = x.transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        # Residual connection followed by layer normalization.
        output = self.layer_norm(output + residual)
        return output
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention with key/query padding masks and a causal (lower-triangular) mask."""

    def __init__(self, hidden_size, num_units, num_heads, dropout_rate):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0

        # Note: num_units is expected to equal hidden_size so that the head
        # split and the residual connection in forward() line up.
        self.linear_q = nn.Linear(hidden_size, num_units)
        self.linear_k = nn.Linear(hidden_size, num_units)
        self.linear_v = nn.Linear(hidden_size, num_units)
        self.dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)
|
    def forward(self, queries, keys):
        """
        :param queries: A 3d tensor with shape of [N, T_q, C_q]
        :param keys: A 3d tensor with shape of [N, T_k, C_k]

        :return: A 3d tensor with shape of [N, T_q, C]
        """
        # Linear projections.
        Q = self.linear_q(queries)
        K = self.linear_k(keys)
        V = self.linear_v(keys)
|
        # Split into num_heads chunks and stack them along the batch dimension.
        split_size = self.hidden_size // self.num_heads
        Q_ = torch.cat(torch.split(Q, split_size, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Scaled dot-product attention scores: (h*N, T_q, T_k).
        matmul_output = torch.bmm(Q_, K_.transpose(1, 2)) / self.hidden_size ** 0.5
|
        # Key masking: key positions whose embeddings sum to zero are treated as
        # padding and pushed to a large negative value before the softmax.
        key_mask = torch.sign(torch.abs(keys.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_k)
        key_mask_reshaped = key_mask.unsqueeze(1).repeat(1, queries.shape[1], 1)  # (h*N, T_q, T_k)
        key_paddings = torch.ones_like(matmul_output) * (-2 ** 32 + 1)
        matmul_output_m1 = torch.where(torch.eq(key_mask_reshaped, 0), key_paddings, matmul_output)
|
        # Causality masking: a lower-triangular mask keeps each query position
        # from attending to future key positions.
        diag_vals = torch.ones_like(matmul_output[0, :, :])  # (T_q, T_k)
        tril = torch.tril(diag_vals)  # (T_q, T_k)
        causality_mask = tril.unsqueeze(0).repeat(matmul_output.shape[0], 1, 1)  # (h*N, T_q, T_k)
        causality_paddings = torch.ones_like(causality_mask) * (-2 ** 32 + 1)
        matmul_output_m2 = torch.where(torch.eq(causality_mask, 0), causality_paddings, matmul_output_m1)
|
        # Softmax over the key dimension gives the attention weights.
        matmul_output_sm = self.softmax(matmul_output_m2)  # (h*N, T_q, T_k)
|
        # Query masking: zero out attention rows that correspond to padded query positions.
        query_mask = torch.sign(torch.abs(queries.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_q)
        query_mask = query_mask.unsqueeze(-1).repeat(1, 1, keys.shape[1])  # (h*N, T_q, T_k)
        matmul_output_qm = matmul_output_sm * query_mask
|
        # Dropout on the attention weights.
        matmul_output_dropout = self.dropout(matmul_output_qm)
|
        # Weighted sum of the value vectors: (h*N, T_q, C/h).
        output_ws = torch.bmm(matmul_output_dropout, V_)
|
        # Undo the head split to restore the original shape: (N, T_q, C).
        output = torch.cat(torch.split(output_ws, output_ws.shape[0] // self.num_heads, dim=0), dim=2)
|
        # Residual connection.
        output_res = output + queries

        return output_res
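

# Minimal smoke-test sketch: the sizes below are illustrative only, and it
# assumes hidden_size == num_units, which the head split and the residual
# connection in MultiHeadAttention require.
if __name__ == "__main__":
    batch_size, seq_len, hidden_size, num_heads = 2, 5, 8, 2
    x = torch.rand(batch_size, seq_len, hidden_size)

    attn = MultiHeadAttention(hidden_size=hidden_size, num_units=hidden_size,
                              num_heads=num_heads, dropout_rate=0.1)
    ffn = PositionwiseFeedForward(d_in=hidden_size, d_hid=4 * hidden_size, dropout=0.1)

    attn_out = attn(x, x)    # self-attention: queries and keys are the same sequence
    ffn_out = ffn(attn_out)  # position-wise feed-forward over the attended sequence

    print(attn_out.shape)  # torch.Size([2, 5, 8])
    print(ffn_out.shape)   # torch.Size([2, 5, 8])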
|
|