# MotionGPT: mGPT/archs/tools/embeddings.py
# This file is taken from the signjoey repository
import math

import torch
from torch import Tensor, nn


def freeze_params(module: nn.Module) -> None:
    """
    Freeze the parameters of a module, i.e. exclude them from gradient updates.
    Referenced below when freeze=True but not defined in this standalone copy;
    this is the signjoey helper, inlined here.
    """
    for _, p in module.named_parameters():
        p.requires_grad = False


def get_activation(activation_type):
if activation_type == "relu":
return nn.ReLU()
elif activation_type == "relu6":
return nn.ReLU6()
elif activation_type == "prelu":
return nn.PReLU()
elif activation_type == "selu":
return nn.SELU()
elif activation_type == "celu":
return nn.CELU()
elif activation_type == "gelu":
return nn.GELU()
elif activation_type == "sigmoid":
return nn.Sigmoid()
elif activation_type == "softplus":
return nn.Softplus()
elif activation_type == "softshrink":
return nn.Softshrink()
elif activation_type == "softsign":
return nn.Softsign()
elif activation_type == "tanh":
return nn.Tanh()
elif activation_type == "tanhshrink":
return nn.Tanhshrink()
else:
raise ValueError("Unknown activation type {}".format(activation_type))
class MaskedNorm(nn.Module):
"""
Original Code from:
https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
"""
    def __init__(self, norm_type: str, num_groups: int, num_features: int):
super().__init__()
self.norm_type = norm_type
if self.norm_type == "batch":
self.norm = nn.BatchNorm1d(num_features=num_features)
elif self.norm_type == "group":
self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
elif self.norm_type == "layer":
self.norm = nn.LayerNorm(normalized_shape=num_features)
else:
raise ValueError("Unsupported Normalization Layer")
self.num_features = num_features
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        if self.training:
            # Normalize only the valid (unpadded) positions: flatten to
            # (batch * time, features) and select the rows where the mask is
            # nonzero, so padding does not pollute the statistics.
            reshaped = x.reshape([-1, self.num_features])
            reshaped_mask = mask.reshape([-1, 1]) > 0
            selected = torch.masked_select(reshaped, reshaped_mask).reshape(
                [-1, self.num_features]
            )
            batch_normed = self.norm(selected)
            # Scatter the normalized values back into their original positions;
            # padded positions keep their (unnormalized) input values.
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape([x.shape[0], -1, self.num_features])
        else:
            # In eval mode every position is normalized; for BatchNorm this
            # uses the running statistics, so the mask is not required.
            reshaped = x.reshape([-1, self.num_features])
            batch_normed = self.norm(reshaped)
            return batch_normed.reshape([x.shape[0], -1, self.num_features])
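

def _demo_masked_norm():
    # Illustrative sketch (not part of the original signjoey code): shows how
    # MaskedNorm uses a (batch, time, 1) mask so that padded positions do not
    # contribute to the batch statistics. All shapes and values below are
    # arbitrary assumptions for the demo.
    norm = MaskedNorm(norm_type="batch", num_groups=1, num_features=16)
    x = torch.randn(2, 5, 16)      # (batch, time, features)
    mask = torch.ones(2, 5, 1)     # 1 = valid frame, 0 = padding
    mask[1, 3:] = 0                # second sequence has only 3 valid frames
    out = norm(x, mask)            # padded rows are passed through unnormalized
    return out.shape               # torch.Size([2, 5, 16])
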
# TODO (Cihan): Spatial and Word Embeddings are pretty much the same
# We might as well convert them into a single module class.
# Only difference is the lut vs linear layers.
class Embeddings(nn.Module):
"""
Simple embeddings class
"""
# pylint: disable=unused-argument
def __init__(
self,
embedding_dim: int = 64,
num_heads: int = 8,
scale: bool = False,
scale_factor: float = None,
norm_type: str = None,
activation_type: str = None,
vocab_size: int = 0,
padding_idx: int = 1,
freeze: bool = False,
**kwargs
):
"""
Create new embeddings for the vocabulary.
Use scaling for the Transformer.
:param embedding_dim:
:param scale:
:param vocab_size:
:param padding_idx:
:param freeze: freeze the embeddings during training
"""
super().__init__()
self.embedding_dim = embedding_dim
self.vocab_size = vocab_size
self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)
self.norm_type = norm_type
if self.norm_type:
self.norm = MaskedNorm(
norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
)
self.activation_type = activation_type
if self.activation_type:
self.activation = get_activation(activation_type)
self.scale = scale
if self.scale:
if scale_factor:
self.scale_factor = scale_factor
else:
self.scale_factor = math.sqrt(self.embedding_dim)
if freeze:
freeze_params(self)
# pylint: disable=arguments-differ
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
"""
Perform lookup for input `x` in the embedding table.
:param mask: token masks
:param x: index in the vocabulary
:return: embedded representation for `x`
"""
x = self.lut(x)
if self.norm_type:
x = self.norm(x, mask)
if self.activation_type:
x = self.activation(x)
if self.scale:
return x * self.scale_factor
else:
return x
def __repr__(self):
return "%s(embedding_dim=%d, vocab_size=%d)" % (
self.__class__.__name__,
self.embedding_dim,
self.vocab_size,
)
class SpatialEmbeddings(nn.Module):
"""
Simple Linear Projection Layer
(For encoder outputs to predict glosses)
"""
# pylint: disable=unused-argument
def __init__(
self,
embedding_dim: int,
input_size: int,
num_heads: int,
freeze: bool = False,
norm_type: str = "batch",
activation_type: str = "softsign",
scale: bool = False,
scale_factor: float = None,
**kwargs
):
"""
Create new embeddings for the vocabulary.
Use scaling for the Transformer.
:param embedding_dim:
:param input_size:
:param freeze: freeze the embeddings during training
"""
super().__init__()
self.embedding_dim = embedding_dim
self.input_size = input_size
self.ln = nn.Linear(self.input_size, self.embedding_dim)
self.norm_type = norm_type
if self.norm_type:
self.norm = MaskedNorm(
norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
)
self.activation_type = activation_type
if self.activation_type:
self.activation = get_activation(activation_type)
self.scale = scale
if self.scale:
if scale_factor:
self.scale_factor = scale_factor
else:
self.scale_factor = math.sqrt(self.embedding_dim)
if freeze:
freeze_params(self)
# pylint: disable=arguments-differ
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
"""
:param mask: frame masks
:param x: input frame features
:return: embedded representation for `x`
"""
x = self.ln(x)
if self.norm_type:
x = self.norm(x, mask)
if self.activation_type:
x = self.activation(x)
if self.scale:
return x * self.scale_factor
else:
return x
def __repr__(self):
return "%s(embedding_dim=%d, input_size=%d)" % (
self.__class__.__name__,
self.embedding_dim,
self.input_size,
)
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
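

def _demo_timestep_embedding():
    # Illustrative sketch (not part of the original code): the first half of
    # the output holds the sine terms and the second half the cosine terms
    # (unless flip_sin_to_cos=True swaps them). Timesteps and dimension below
    # are arbitrary assumptions for the demo.
    t = torch.tensor([0, 10, 500], dtype=torch.float32)  # one timestep per sample
    emb = get_timestep_embedding(t, embedding_dim=128)
    return emb.shape                                     # torch.Size([3, 128])
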
class TimestepEmbedding(nn.Module):
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
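

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original file): a diffusion-style
    # pipeline first maps integer timesteps to sinusoidal features with
    # Timesteps, then projects them with TimestepEmbedding. The channel sizes
    # below are arbitrary assumptions.
    time_proj = Timesteps(num_channels=128, flip_sin_to_cos=True, downscale_freq_shift=0)
    time_embed = TimestepEmbedding(channel=128, time_embed_dim=512)
    t = torch.randint(0, 1000, (4,))   # one timestep per batch element
    t_emb = time_embed(time_proj(t))   # (4, 512)
    print(t_emb.shape)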