# MultiTalk-Code: models/lib/base_models.py
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops.layers.torch import Rearrange

from utils.base_model_util import *  # provides get_activation and get_shape_list
class Norm(nn.Module):
    """Pre-normalization wrapper: applies LayerNorm before the wrapped layer.

    Dict inputs carry a cross-modal pair {'x_a', 'x_b'} (only 'x_b' is
    normalized); tuple inputs carry (x, mask_info).
    """
    def __init__(self, fn, size):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-5)
        self.fn = fn

    def forward(self, x_data):
        if isinstance(x_data, dict):
            x_norm = self.fn({'x_a': x_data['x_a'], 'x_b': self.norm(x_data['x_b'])})
            return x_norm
        else:
            x, mask_info = x_data
            x_norm, _ = self.fn((self.norm(x), mask_info))
            return (x_norm, mask_info)
class Residual(nn.Module):
    """Residual (skip) connection around the wrapped layer.

    Handles the same dict / (x, mask_info) tuple conventions as Norm.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x_data):
        if isinstance(x_data, dict):
            x_resid = self.fn(x_data)['x_b']
            return {'x_a': x_data['x_a'], 'x_b': x_resid + x_data['x_b']}
        else:
            x, mask_info = x_data
            x_resid, _ = self.fn(x_data)
            return (x_resid + x, mask_info)
class MLP(nn.Module):
    """Two-layer feed-forward block with a GELU activation in between."""
    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        self.l1 = nn.Linear(in_dim, hidden_dim)
        self.activation = get_activation("gelu")
        self.l2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x_data):
        if isinstance(x_data, dict):
            out = self.l2(self.activation(self.l1(x_data['x_b'])))
            return {'x_a': x_data['x_a'], 'x_b': out}
        else:
            x, mask_info = x_data
            out = self.l2(self.activation(self.l1(x)))
            return (out, mask_info)
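# Usage sketch (not part of the original file): how a Residual(Norm(MLP(...)))
# block consumes the (x, mask_info) tuple convention used throughout this file.
# The shapes and the helper name `_example_ffn_block` are illustrative assumptions.
def _example_ffn_block():
    block = Residual(Norm(MLP(768, 768, 3072), 768))
    x = torch.randn(2, 16, 768)                   # (batch, time, features)
    mask_info = {'max_mask': None, 'mask': None}  # no attention mask needed here
    out, _ = block((x, mask_info))
    return out.shape                              # torch.Size([2, 16, 768])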
class CrossModalAttention(nn.Module):
    """Cross-modal attention layer.

    Given two modalities (a, b), computes K and V from modality b and Q from
    modality a, then attends from a onto b.
    """
    def __init__(self, in_dim, dim, heads=8, in_dim2=None):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5  # note: scales by the full model dim, not the per-head dim
        if in_dim2 is not None:
            self.to_kv = nn.Linear(in_dim2, in_dim2 * 2, bias=False)
        else:
            self.to_kv = nn.Linear(in_dim, dim * 2, bias=False)
        self.to_q = nn.Linear(in_dim, dim, bias=False)
        if in_dim2 is not None:
            dim2 = int((in_dim + in_dim2 * 2) / 3)
        else:
            dim2 = dim
        self.to_out = nn.Linear(dim2, dim)
        # Q, K, V are concatenated along the feature dim and split back out here.
        self.rearrange_qkv = Rearrange(
            "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")

    def forward(self, x_data):
        x_a = x_data['x_a']
        x_b = x_data['x_b']
        kv = self.to_kv(x_b)
        q = self.to_q(x_a)
        qkv = torch.cat((q, kv), dim=-1)
        qkv = self.rearrange_qkv(qkv)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention from modality a onto modality b.
        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(dots, dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return {'x_a': x_a, 'x_b': out}
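# Usage sketch (not part of the original file): CrossModalAttention on a dict
# holding the two modality streams. With in_dim2 unset, both streams share the
# same feature size; names and shapes below are illustrative assumptions.
def _example_cross_modal_attention():
    xattn = CrossModalAttention(in_dim=768, dim=768, heads=8)
    x_a = torch.randn(2, 16, 768)   # query modality (e.g. listener motion)
    x_b = torch.randn(2, 16, 768)   # key/value modality (e.g. speaker motion+audio)
    out = xattn({'x_a': x_a, 'x_b': x_b})
    return out['x_b'].shape         # torch.Size([2, 16, 768])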
class Attention(nn.Module):
    """Multi-head self-attention layer with an optional binary attention mask."""
    def __init__(self, in_dim, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5  # note: scales by the full model dim, not the per-head dim
        self.to_qkv = nn.Linear(in_dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.rearrange_qkv = Rearrange(
            "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")

    def forward(self, x_data):
        x, mask_info = x_data
        max_mask = mask_info['max_mask']
        mask = mask_info['mask']

        qkv = self.to_qkv(x)
        qkv = self.rearrange_qkv(qkv)
        q, k, v = qkv[0], qkv[1], qkv[2]
        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if max_mask is not None:
            # Mask out scores within the first max_mask positions wherever the
            # binary mask is zero, so they vanish after the softmax.
            dots[:, :, :max_mask, :max_mask] = \
                dots[:, :, :max_mask, :max_mask].masked_fill(mask == 0., float('-inf'))
        attn = F.softmax(dots, dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return (out, mask_info)
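# Usage sketch (not part of the original file): self-attention with a causal
# mask over the first `max_mask` positions. The mask shape used here,
# (1, 1, max_mask, max_mask), is an assumption chosen so it broadcasts against
# the (batch, heads, query, key) score tensor.
def _example_masked_attention():
    attn = Attention(in_dim=768, dim=768, heads=12)
    x = torch.randn(2, 16, 768)
    causal = torch.tril(torch.ones(16, 16)).view(1, 1, 16, 16)
    out, _ = attn((x, {'max_mask': 16, 'mask': causal}))
    return out.shape  # torch.Size([2, 16, 768])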
class Transformer(nn.Module):
    """Transformer stack.

    Parameters
    ----------
    cross_modal : bool
        if True, uses cross-modal attention layers; otherwise a vanilla
        self-attention Transformer
    in_dim2 : int
        feature size of the second modality when cross_modal is True
    """
    def __init__(self,
                 in_size=50,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 cross_modal=False,
                 in_dim2=None):
        super().__init__()
        blocks = []
        self.cross_modal = cross_modal
        if cross_modal:
            for i in range(num_hidden_layers):
                blocks.extend([
                    Residual(Norm(CrossModalAttention(in_size, hidden_size,
                                                      heads=num_attention_heads,
                                                      in_dim2=in_dim2), hidden_size)),
                    Residual(Norm(MLP(hidden_size, hidden_size, intermediate_size),
                                  hidden_size))
                ])
        else:
            for i in range(num_hidden_layers):
                blocks.extend([
                    Residual(Norm(Attention(in_size, hidden_size,
                                            heads=num_attention_heads), hidden_size)),
                    Residual(Norm(MLP(hidden_size, hidden_size, intermediate_size),
                                  hidden_size))
                ])
        self.net = torch.nn.Sequential(*blocks)

    def forward(self, x_data):
        if self.cross_modal:
            assert isinstance(x_data, dict)
            x_data = self.net(x_data)
            x = x_data['x_b']
        else:
            x, mask_info = x_data
            x, _ = self.net((x, mask_info))
        return x
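# Usage sketch (not part of the original file): a small vanilla Transformer on
# the (x, mask_info) tuple interface. Layer sizes are illustrative assumptions.
def _example_transformer():
    model = Transformer(in_size=256, hidden_size=256, num_hidden_layers=2,
                        num_attention_heads=4, intermediate_size=512)
    x = torch.randn(2, 32, 256)
    out = model((x, {'max_mask': None, 'mask': None}))
    return out.shape  # torch.Size([2, 32, 256])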
class LinearEmbedding(nn.Module):
""" Linear Layer """
def __init__(self, size, dim):
super().__init__()
self.net = nn.Linear(size, dim)
def forward(self, x):
return self.net(x)
class AudioEmbedding(nn.Module):
    """Audio embedding layer.

    Parameters
    ----------
    size : int
        input feature size of the audio embedding
    dim : int
        desired output feature size of the audio embedding
    quant_factor : int
        number of 2x max-pool layers applied along the temporal dimension
        (after an initial 4x max-pool)
    version : str (default 'v6')
        which version of the audio embedding to use; only 'v6' is implemented
    """
    def __init__(self, size, dim, quant_factor, version='v6'):
        super().__init__()
        self.proj = None
        if version == 'v6':
            print('MODEL V6')
            # Initial 4x temporal downsampling, then quant_factor 2x poolings.
            self.net = nn.MaxPool1d(4)
            layers = [nn.Sequential(nn.MaxPool1d(2))]
            for _ in range(1, quant_factor):
                layers += [nn.Sequential(nn.MaxPool1d(2))]
            self.squasher = nn.Sequential(*layers)
            self.proj = nn.Linear(size, dim)
        else:
            raise NotImplementedError('unknown AudioEmbedding version: %s' % version)

    def forward(self, x):
        # x: (batch, features, time); pooling acts on the temporal dimension.
        x = self.net(x)
        x = self.squasher(x)
        if self.proj is not None:
            # Project features, temporarily moving them to the last dimension.
            x = self.proj(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x
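# Usage sketch (not part of the original file): downsampling a
# (batch, features, time) audio stream. With quant_factor=3 the temporal length
# is reduced by 4 * 2**3 = 32; shapes below are illustrative assumptions.
def _example_audio_embedding():
    embed = AudioEmbedding(size=128, dim=200, quant_factor=3)
    audio = torch.randn(2, 128, 256)    # (batch, features, time)
    out = embed(audio)
    return out.shape                    # torch.Size([2, 200, 8])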
class PositionEmbedding(nn.Module):
    """Learned position embedding layer."""
    def __init__(self, seq_length, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(seq_length, dim))

    def forward(self, x):
        return x + self.pos_embedding
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (expects (seq, batch, dim) inputs)."""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Standard sinusoidal table: sin on even indices, cos on odd indices.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
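# Usage sketch (not part of the original file): PositionalEncoding adds the
# sinusoidal table along the first axis, so inputs are laid out as
# (seq, batch, dim); the values below are illustrative assumptions.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
    x = torch.randn(32, 2, 256)     # (seq, batch, dim)
    return pos_enc(x).shape         # torch.Size([32, 2, 256])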
class CrossModalLayer(nn.Module):
    """Cross-modal layer inspired by FACT [Li 2021]."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        model_config = self.config['transformer']
        self.transformer_layer = Transformer(
            in_size=model_config['hidden_size'],
            hidden_size=model_config['hidden_size'],
            num_hidden_layers=model_config['num_hidden_layers'],
            num_attention_heads=model_config['num_attention_heads'],
            intermediate_size=model_config['intermediate_size'])
        output_layer_config = self.config['output_layer']
        self.cross_norm_layer = nn.LayerNorm(self.config['in_dim'])
        self.cross_output_layer = nn.Linear(
            self.config['in_dim'],
            output_layer_config['out_dim'],
            bias=False)
        self.cross_pos_embedding = PositionEmbedding(
            self.config["sequence_length"], self.config['in_dim'])

    def forward(self, modal_a_sequences, modal_b_sequences, mask_info):
        """
        Parameters
        ----------
        modal_a_sequences : tensor
            the first modality (e.g. listener motion embedding)
        modal_b_sequences : tensor
            the second modality (e.g. speaker motion+audio embedding)
        mask_info : dict
            binary mask applied to the Transformer attention
        """
        _, _, modal_a_width = get_shape_list(modal_a_sequences)
        merged_sequences = modal_a_sequences
        if modal_b_sequences is not None:
            _, _, modal_b_width = get_shape_list(modal_b_sequences)
            if modal_a_width != modal_b_width:
                raise ValueError(
                    "The modal_a hidden size (%d) should be the same as the "
                    "modal_b hidden size (%d)" % (modal_a_width, modal_b_width))
            # Concatenate the two modalities along the temporal dimension.
            merged_sequences = torch.cat([merged_sequences, modal_b_sequences],
                                         dim=1)
        merged_sequences = self.cross_pos_embedding(merged_sequences)
        merged_sequences = self.transformer_layer((merged_sequences, mask_info))
        merged_sequences = self.cross_norm_layer(merged_sequences)
        logits = self.cross_output_layer(merged_sequences)
        return logits
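# Usage sketch (not part of the original file): the nested config dict expected
# by CrossModalLayer, with small illustrative sizes. It is assumed that in_dim
# matches the transformer hidden size and that sequence_length equals the
# combined length of the two modalities.
def _example_cross_modal_layer():
    config = {
        'in_dim': 256,
        'sequence_length': 32,
        'transformer': {
            'hidden_size': 256,
            'num_hidden_layers': 2,
            'num_attention_heads': 4,
            'intermediate_size': 512,
        },
        'output_layer': {'out_dim': 128},
    }
    layer = CrossModalLayer(config)
    modal_a = torch.randn(2, 16, 256)   # e.g. listener motion embedding
    modal_b = torch.randn(2, 16, 256)   # e.g. speaker motion+audio embedding
    logits = layer(modal_a, modal_b, {'max_mask': None, 'mask': None})
    return logits.shape                 # torch.Size([2, 32, 128])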