|
|
|
|
|
import numpy as np |
|
import six |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
def dropout(input_tensor, dropout_prob, training=True):
    """Perform dropout.

    Args:
      input_tensor: float Tensor.
      dropout_prob: Python float. The probability of dropping out a value (NOT of
        *keeping* a dimension as in `tf.nn.dropout`). `None` or 0.0 disables
        dropout and returns `input_tensor` unchanged.
      training: if False, dropout is a no-op (useful at evaluation time).
        Defaults to True so existing callers always get dropout applied.

    Returns:
      A version of `input_tensor` with dropout applied. Surviving elements are
      scaled by 1 / (1 - dropout_prob) (inverted dropout).
    """
    if dropout_prob is None or dropout_prob == 0.0:
        return input_tensor

    # `nn.Dropout` is a module class and cannot be applied directly to a tensor;
    # the functional form does the actual masking and rescaling.
    return torch.nn.functional.dropout(
        input_tensor, p=dropout_prob, training=training)
|
|
|
|
|
def create_look_ahead_mask(seq_length, batch_size=0):
    """Create a look-ahead (causal) mask for a given sequence length.

    Entries strictly above the diagonal are 1 (positions to be masked out);
    entries on or below the diagonal are 0.

    Args:
      seq_length: int, the length of the sequence.
      batch_size: if batch_size is positive, the mask is repeated along a new
        leading batch dimension.

    Returns:
      A float mask of shape (seq_length, seq_length), or
      (batch_size, seq_length, seq_length) when batch_size > 0.
    """
    mask = 1 - torch.tril(torch.ones((seq_length, seq_length)))
    if batch_size > 0:
        # torch has no module-level `repeat(..., dim=)`; tile via Tensor.repeat.
        mask = mask.unsqueeze(0).repeat(batch_size, 1, 1)
    return mask
|
|
|
|
|
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Create 3D attention mask from a 2D tensor mask.

    Args:
      from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
      to_mask: integer Tensor of shape [batch_size, to_seq_length].

    Returns:
      float32 Tensor of shape [batch_size, from_seq_length, to_seq_length] in
      which every "from" row is a copy of that example's `to_mask`.
    """
    batch_size = from_tensor.shape[0]
    from_seq_length = from_tensor.shape[1]
    to_seq_length = to_mask.shape[1]

    to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

    # `torch.ones` takes the size positionally (there is no `shape=` keyword).
    # Allocate on the mask's device so the product does not mix CPU/GPU tensors.
    broadcast_ones = torch.ones(
        (batch_size, from_seq_length, 1),
        dtype=torch.float32,
        device=to_mask.device)

    # Broadcasting (B, F, 1) * (B, 1, T) -> (B, F, T).
    return broadcast_ones * to_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gelu(x):
    """Gaussian Error Linear Unit, tanh approximation.

    A smooth alternative to ReLU, from https://arxiv.org/abs/1606.08415.
    Computes x * Phi(x), where the Gaussian CDF Phi is approximated as
    0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).

    Args:
      x: float Tensor to activate.

    Returns:
      `x` with the GELU activation applied elementwise.
    """
    scale = np.sqrt(2 / np.pi)
    inner = scale * (x + 0.044715 * torch.pow(x, 3))
    gate = 0.5 * (1.0 + torch.tanh(inner))
    return x * gate
|
|
|
|
|
def get_activation(activation_string):
    """Maps a string to a Python callable, e.g., "relu" => `torch.relu`.

    Args:
      activation_string: String name of the activation function
        (case-insensitive).

    Returns:
      A callable that applies the activation to a tensor. If
      `activation_string` is empty or "linear", this returns None.
      If `activation_string` is not a string (including None), it is
      returned unchanged, so an already-resolved callable passes through.

    Raises:
      ValueError: The `activation_string` does not correspond to a known
        activation.
    """
    # On Python 3, `six.string_types` is exactly `(str,)`.
    if not isinstance(activation_string, str):
        return activation_string

    if not activation_string:
        return None

    act = activation_string.lower()
    if act == "linear":
        return None
    if act == "relu":
        # Return the function, not the `nn.ReLU` module class: callers apply the
        # result directly to a tensor.
        return torch.relu
    if act == "gelu":
        return gelu
    if act == "tanh":
        return torch.tanh
    raise ValueError("Unsupported activation: %s" % act)
|
|
|
|
|
def get_shape_list(tensor):
    """Returns the shape of `tensor`, requiring every dimension to be static.

    Args:
      tensor: an object exposing `.size()`, typically a `torch.Tensor`.

    Returns:
      The value of `tensor.size()` (a `torch.Size` for torch tensors), which
      callers index like a list of ints.

    Raises:
      ValueError: if any dimension is None (not statically known).
    """
    shape = tensor.size()

    non_static_indexes = [index for index, dim in enumerate(shape) if dim is None]
    if non_static_indexes:
        # Raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would make this silently return None.
        raise ValueError(
            "something wrong with static shaping: dynamic dimensions at %s"
            % non_static_indexes)
    return shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch.

    Args:
      sequence_tensor: Tensor of shape [batch_size, seq_length, width].
      positions: integer Tensor (or nested list) of shape
        [batch_size, num_positions]; each entry indexes the seq_length axis of
        the corresponding example.

    Returns:
      Tensor of shape [batch_size, num_positions, width] where
      output[b, i] == sequence_tensor[b, positions[b, i]].
    """
    batch_size, seq_length, width = sequence_tensor.shape

    # Indexing ops require int64 indices on the same device as the data.
    positions = torch.as_tensor(
        positions, dtype=torch.long, device=sequence_tensor.device)

    # Offset each example's positions into the flattened (batch * seq) axis.
    # `torch.range` is inclusive (batch_size + 1 values) and deprecated; use arange.
    flat_offsets = torch.reshape(
        torch.arange(batch_size, device=sequence_tensor.device) * seq_length,
        (-1, 1))
    flat_positions = torch.reshape(positions + flat_offsets, (-1,))
    flat_sequence_tensor = torch.reshape(
        sequence_tensor, (batch_size * seq_length, width))

    # TF's `tf.gather(params, indices)` selects rows along axis 0, which is
    # `index_select(0, ...)` in torch (`torch.gather` has different semantics).
    output_tensor = torch.index_select(flat_sequence_tensor, 0, flat_positions)
    return torch.reshape(output_tensor, (batch_size, -1, width))
|
|
|
|
|
def split_heads(x, batch_size, seq_length, num_joints, num_attention_heads,
                model_depth):
    """Split the embedding vector into attention heads.

    The last axis (model_depth) is split into num_attention_heads chunks of
    size model_depth // num_attention_heads; head h receives the contiguous
    slice [h * depth, (h + 1) * depth).

    Args:
      x: the embedding tensor, (batch_size, seq_len, num_joints, model_depth)
        or (batch_size, seq_len, model_depth)
      batch_size: the batch size
      seq_length: the sequence length
      num_joints: the number of joints (only used for 4-D input)
      num_attention_heads: the number of attention heads
      model_depth: the model depth

    Returns:
      the split tensor (batch_size, seq_len, num_heads, num_joints, depth) or
      (batch_size, num_heads, seq_len, depth)

    Raises:
      ValueError: if `x` is neither 3-D nor 4-D.
    """
    depth = model_depth // num_attention_heads
    # `get_shape().as_list()` is TensorFlow API; torch exposes rank via dim().
    rank = x.dim()
    if rank == 4:
        x = torch.reshape(
            x, (batch_size, seq_length, num_joints, num_attention_heads, depth))
        return x.permute(0, 1, 3, 2, 4)
    if rank == 3:
        x = torch.reshape(x, (batch_size, seq_length, num_attention_heads, depth))
        return x.permute(0, 2, 1, 3)
    raise ValueError("Unsupported input tensor dimension.")
|
|
|
|
|
def scaled_dot_product_attention(q, k, v, mask):
    """The scaled dot product attention mechanism.

    Attn(Q, K, V) = softmax(QK^T / sqrt(depth) + mask * -1e9) V

    Args:
      q: the query vectors matrix (..., seq_len_q, depth)
      k: the key vectors matrix (..., seq_len_k, depth)
      v: the value vectors matrix (..., seq_len_k, depth_v)
      mask: None, or a tensor broadcastable to (..., seq_len_q, seq_len_k)
        where 1 marks positions to exclude and 0 marks positions to keep.

    Returns:
      A tuple (output, attention_weights) with shapes
      (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
    """
    # Transpose only the last two axes so leading batch/head axes are kept.
    matmul_qk = q @ k.transpose(-2, -1)

    # Scale by sqrt(depth) to keep logits in a range where softmax is not saturated.
    dk = float(k.size(-1))
    scaled_attention_logits = matmul_qk / (dk ** 0.5)

    if mask is not None:
        # A large negative logit drives the masked weight to ~0 after softmax.
        scaled_attention_logits = scaled_attention_logits + mask * -1e9

    attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
    output = attention_weights @ v
    return output, attention_weights
|
|