import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need".
    Compute the dot products of the query with all keys, divide each by sqrt(dim),
    and apply a softmax function to obtain the weights on the values.

    Args: dim
        dim (int): dimension of attention

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing the projection vector for the decoder.
        - **key** (batch, k_len, d_model): tensor containing the projection vector for the encoder.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): tensor containing indices to be masked

    Returns: context, attn
        - **context**: tensor containing the context vector from the attention mechanism.
        - **attn**: tensor containing the attention (alignment) over the encoder outputs.
    """
    def __init__(self, dim: int) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            score.masked_fill_(mask.view(score.size()), -float('Inf'))

        attn = F.softmax(score, dim=-1)
        context = torch.bmm(attn, value)
        return context, attn
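
# A minimal usage sketch for ScaledDotProductAttention. The shapes below are illustrative
# assumptions consistent with the docstring, not values taken from this repository:
#
#   attention = ScaledDotProductAttention(dim=64)
#   query = torch.randn(8, 10, 64)    # (batch, q_len, dim)
#   key = torch.randn(8, 20, 64)      # (batch, k_len, dim)
#   value = torch.randn(8, 20, 64)    # (batch, v_len, dim)
#   context, attn = attention(query, key, value)
#   # context: (8, 10, 64), attn: (8, 10, 20)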


class DotProductAttention(nn.Module):
    """
    Compute the dot products of the query with all values and apply a softmax function
    to obtain the weights on the values.

    Inputs: query, value
        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.

    Returns: context, attn
        - **context**: tensor containing the context vector from the attention mechanism.
        - **attn**: tensor containing the attention (alignment) over the encoder outputs.
    """
    def __init__(self, hidden_dim: int) -> None:
        super(DotProductAttention, self).__init__()

    def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, input_size = query.size(0), value.size(1)

        score = torch.bmm(query, value.transpose(1, 2))
        attn = F.softmax(score.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
        context = torch.bmm(attn, value)

        return context, attn
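
# A minimal usage sketch for DotProductAttention (unscaled). Shapes are illustrative
# assumptions; note that the hidden_dim constructor argument is not used internally:
#
#   attention = DotProductAttention(hidden_dim=128)
#   query = torch.randn(4, 1, 128)    # (batch, q_len, hidden_dim)
#   value = torch.randn(4, 30, 128)   # (batch, v_len, hidden_dim)
#   context, attn = attention(query, value)
#   # context: (4, 1, 128), attn: (4, 1, 30)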


class AdditiveAttention(nn.Module):
    """
    Applies an additive (Bahdanau) attention mechanism to the output features from the decoder.
    Additive attention was proposed in the "Neural Machine Translation by Jointly Learning to Align and Translate" paper.

    Args:
        hidden_dim (int): dimension of the hidden state vector

    Inputs: query, key, value
        - **query** (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **key** (batch_size, k_len, hidden_dim): tensor containing the projection vector for the encoder.
        - **value** (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence.

    Returns: context, attn
        - **context**: tensor containing the context vector from the attention mechanism.
        - **attn**: tensor containing the alignment over the encoder outputs.

    Reference:
        - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
    """
    def __init__(self, hidden_dim: int) -> None:
        super(AdditiveAttention, self).__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.score_proj = nn.Linear(hidden_dim, 1)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
        attn = F.softmax(score, dim=-1)
        context = torch.bmm(attn.unsqueeze(1), value)
        return context, attn
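
# A minimal usage sketch for AdditiveAttention. A single-step decoder query (q_len = 1)
# is assumed so that the query broadcasts against the keys; shapes are illustrative:
#
#   attention = AdditiveAttention(hidden_dim=256)
#   query = torch.randn(4, 1, 256)    # (batch, 1, hidden_dim)
#   key = torch.randn(4, 30, 256)     # (batch, v_len, hidden_dim)
#   value = torch.randn(4, 30, 256)   # (batch, v_len, hidden_dim)
#   context, attn = attention(query, key, value)
#   # context: (4, 1, 256), attn: (4, 30)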


class LocationAwareAttention(nn.Module):
    """
    Applies a location-aware attention mechanism to the output features from the decoder.
    Location-aware attention was proposed in the "Attention-Based Models for Speech Recognition" paper.
    The location-aware attention mechanism performs well in speech recognition tasks.
    The implementation follows the attention style of ClovaCall.

    Args:
        hidden_dim (int): dimension of the hidden state vector
        smoothing (bool): flag indicating whether to apply smoothing to the attention scores

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
        - **last_attn** (batch, v_len): tensor containing the previous timestep's attention (alignment)

    Returns: context, attn
        - **context** (batch, hidden_dim): tensor containing the context vector from the attention mechanism.
        - **attn** (batch, v_len): tensor containing the attention (alignment) over the encoder outputs.

    Reference:
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
        - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
    """
    def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
        super(LocationAwareAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.smoothing = smoothing

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, hidden_dim, seq_len = query.size(0), query.size(2), value.size(1)

        # On the first decoding step there is no previous alignment, so start from zeros.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size, seq_len)

        conv_attn = torch.transpose(self.conv1d(last_attn.unsqueeze(1)), 1, 2)
        score = self.score_proj(torch.tanh(
                self.query_proj(query.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
                + self.value_proj(value.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
                + conv_attn
                + self.bias
        )).squeeze(dim=-1)

        if self.smoothing:
            score = torch.sigmoid(score)
            attn = torch.div(score, score.sum(dim=-1).unsqueeze(dim=-1))
        else:
            attn = F.softmax(score, dim=-1)

        context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1)

        return context, attn
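
# A minimal usage sketch for LocationAwareAttention over several decoding steps. Passing
# last_attn=None on the first step initializes the alignment to zeros; shapes are assumed:
#
#   attention = LocationAwareAttention(hidden_dim=256, smoothing=True)
#   value = torch.randn(4, 50, 256)   # encoder outputs (batch, v_len, hidden_dim)
#   last_attn = None
#   for step_query in [torch.randn(4, 1, 256) for _ in range(3)]:
#       context, last_attn = attention(step_query, value, last_attn)
#       # context: (4, 256), last_attn: (4, 50)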


class MultiHeadLocationAwareAttention(nn.Module):
    """
    Applies a multi-headed location-aware attention mechanism to the output features from the decoder.
    Location-aware attention was proposed in the "Attention-Based Models for Speech Recognition" paper.
    The location-aware attention mechanism performs well in speech recognition tasks.
    The original paper applies a single head; this implementation extends the mechanism to multiple heads.

    Args:
        hidden_dim (int): The number of expected features in the output
        num_heads (int): The number of heads (default: 8)
        conv_out_channel (int): The number of out channels in the convolution

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
        - **last_attn** (batch, num_heads, v_len): tensor containing the previous timestep's attention (alignment)

    Returns: context, attn
        - **context** (batch, output_len, hidden_dim): tensor containing the context vector from the attention mechanism.
        - **attn** (batch, num_heads, v_len): tensor containing the attention (alignment) over the encoder outputs.

    Reference:
        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
    """
    def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
        super(MultiHeadLocationAwareAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dim = int(hidden_dim / num_heads)
        self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.score_proj = nn.Linear(self.dim, 1, bias=True)
        self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, seq_len = value.size(0), value.size(1)

        # On the first decoding step there is no previous alignment, so start from zeros.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)

        # Location energy: convolve the previous alignment and project it to the per-head dimension.
        loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
        loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)

        # Split the query and value projections into per-head chunks of shape (batch * num_heads, ..., dim).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        query = query.contiguous().view(-1, 1, self.dim)
        value = value.contiguous().view(-1, seq_len, self.dim)

        score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
        attn = F.softmax(score, dim=1)

        # value is already laid out as (batch * num_heads, v_len, dim), matching the rows of attn,
        # so it can be used directly in the batched matrix multiplication.
        context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
        attn = attn.view(batch_size, self.num_heads, -1)

        return context, attn
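
# A minimal usage sketch for MultiHeadLocationAwareAttention; hidden_dim must be divisible
# by num_heads. Shapes are illustrative assumptions:
#
#   attention = MultiHeadLocationAwareAttention(hidden_dim=256, num_heads=8, conv_out_channel=10)
#   query = torch.randn(4, 1, 256)    # single decoder step (batch, 1, hidden_dim)
#   value = torch.randn(4, 50, 256)   # encoder outputs (batch, v_len, hidden_dim)
#   context, attn = attention(query, value, last_attn=None)
#   # context: (4, 1, 256), attn: (4, 8, 50)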


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need".
    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
    project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
    These are concatenated and once again projected, resulting in the final values.
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Args:
        d_model (int): The dimension of keys / values / queries (default: 512)
        num_heads (int): The number of attention heads (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): In the Transformer, used in three different ways:
            Case 1: comes from the previous decoder layer
            Case 2: comes from the input embedding
            Case 3: comes from the output embedding (masked)
        - **key** (batch, k_len, d_model): In the Transformer, used in three different ways:
            Case 1: comes from the output of the encoder
            Case 2: comes from the input embedding
            Case 3: comes from the output embedding (masked)
        - **value** (batch, v_len, d_model): In the Transformer, used in three different ways:
            Case 1: comes from the output of the encoder
            Case 2: comes from the input embedding
            Case 3: comes from the output embedding (masked)
        - **mask** (-): tensor containing indices to be masked

    Returns: context, attn
        - **context** (batch, q_len, d_model): tensor containing the attended output features.
        - **attn** (batch * num_heads, q_len, k_len): tensor containing the attention (alignment) over the encoder outputs.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8) -> None:
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)

        # Fold the heads into the batch dimension (head-major ordering: batch * num_heads rows).
        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)

        if mask is not None:
            # Tile the mask across heads so its rows line up with the head-major layout of the scores.
            mask = mask.repeat(self.num_heads, 1, 1)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)

        return context, attn
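
# A minimal usage sketch for MultiHeadAttention with a padding mask. Shapes and the mask
# construction are illustrative assumptions (True marks positions to be masked out):
#
#   attention = MultiHeadAttention(d_model=512, num_heads=8)
#   query = torch.randn(4, 10, 512)   # (batch, q_len, d_model)
#   key = torch.randn(4, 20, 512)     # (batch, k_len, d_model)
#   value = torch.randn(4, 20, 512)   # (batch, v_len, d_model)
#   mask = torch.zeros(4, 10, 20, dtype=torch.bool)   # (batch, q_len, k_len)
#   context, attn = attention(query, key, value, mask)
#   # context: (4, 10, 512), attn: (4 * 8, 10, 20)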


class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" paper.

    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads
        dropout_p (float): probability of dropout

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing the query vector
        - **key** (batch, time, dim): Tensor containing the key vector
        - **value** (batch, time, dim): Tensor containing the value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-head attention module.

    Reference:
        - **Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context**: https://arxiv.org/abs/1901.02860
    """
    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 16,
            dropout_p: float = 0.1,
    ) -> None:
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        # Learnable content and positional biases (u and v in the Transformer-XL paper).
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        pos_score = self._compute_relative_positional_encoding(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, dim=-1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
        # The "relative shift" trick: pad with a column of zeros, reshape, and slice so that
        # each query position is aligned with its relative positions.
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)

        return pos_score
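
# A minimal usage sketch for RelativeMultiHeadAttention. The positional embedding is assumed
# to be precomputed elsewhere (e.g. a sinusoidal relative positional encoding) with the same
# batch/length layout as the inputs; shapes are illustrative:
#
#   attention = RelativeMultiHeadAttention(d_model=512, num_heads=16, dropout_p=0.1)
#   x = torch.randn(4, 32, 512)               # (batch, time, d_model)
#   pos_embedding = torch.randn(4, 32, 512)   # (batch, time, d_model)
#   outputs = attention(x, x, x, pos_embedding)
#   # outputs: (4, 32, 512)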


class CustomizingAttention(nn.Module):
    r"""
    Customizing Attention

    Applies a multi-head + location-aware attention mechanism to the output features from the decoder.
    Multi-head attention was proposed in the "Attention Is All You Need" paper.
    Location-aware attention was proposed in the "Attention-Based Models for Speech Recognition" paper.
    These two attention mechanisms are combined here into a custom attention module.

    Args:
        hidden_dim (int): The number of expected features in the output
        num_heads (int): The number of heads (default: 4)
        conv_out_channel (int): The number of out channels in the convolution

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
        - **last_attn** (batch_size * num_heads, v_len): tensor containing the previous timestep's alignment

    Returns: context, attn
        - **context** (batch, q_len, hidden_dim): tensor containing the attended output features from the decoder.
        - **attn** (batch * num_heads, v_len): tensor containing the alignment from the encoder outputs.

    Reference:
        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
    """
    def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
        super(CustomizingAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dim = int(hidden_dim / num_heads)
        self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
        self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
        self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)

        # On the first decoding step there is no previous alignment, so start from zeros.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size * self.num_heads, v_len)

        loc_energy = self.get_loc_energy(last_attn, batch_size, v_len)  # (batch, v_len, num_heads * dim)

        query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
        value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias

        # Fold the heads into the batch dimension: (batch * num_heads, ..., dim).
        query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
        value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
        query = query.contiguous().view(-1, q_len, self.dim)
        value = value.contiguous().view(-1, v_len, self.dim)

        # The location-augmented value tensor serves as both key and value for the scaled dot-product attention.
        context, attn = self.scaled_dot_attn(query, value, value)
        attn = attn.squeeze()

        context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
        context = context.contiguous().view(batch_size, q_len, -1)

        return context, attn

    def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
        # Convolve the previous alignment and project it to the per-head dimension,
        # then lay the heads out side by side along the feature dimension.
        conv_feat = self.conv1d(last_attn.unsqueeze(1))
        conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)

        loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
        loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)

        return loc_energy
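
# A minimal usage sketch for CustomizingAttention. Shapes are illustrative assumptions;
# the previous alignment is passed back in on subsequent decoding steps:
#
#   attention = CustomizingAttention(hidden_dim=256, num_heads=4, conv_out_channel=10)
#   query = torch.randn(4, 1, 256)    # single decoder step (batch, 1, hidden_dim)
#   value = torch.randn(4, 50, 256)   # encoder outputs (batch, v_len, hidden_dim)
#   context, attn = attention(query, value, last_attn=None)
#   # context: (4, 1, 256), attn: (4 * 4, 50)
#   context, attn = attention(query, value, last_attn=attn)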