import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops.layers.torch import Rearrange

from utils.base_model_util import get_activation, get_shape_list


class Norm(nn.Module):
    """ Norm Layer """

    def __init__(self, fn, size):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-5)
        self.fn = fn

    def forward(self, x_data):
        if type(x_data) is dict:
            x_norm = self.fn({'x_a': x_data['x_a'],
                              'x_b': self.norm(x_data['x_b'])})
            return x_norm
        else:
            x, mask_info = x_data
            x_norm, _ = self.fn((self.norm(x), mask_info))
            return (x_norm, mask_info)


class Residual(nn.Module):
    """ Residual Layer """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x_data):
        if type(x_data) is dict:
            x_resid = self.fn(x_data)['x_b']
            return {'x_a': x_data['x_a'], 'x_b': x_resid + x_data['x_b']}
        else:
            x, mask_info = x_data
            x_resid, _ = self.fn(x_data)
            return (x_resid + x, mask_info)


class MLP(nn.Module):
    """ MLP Layer """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        self.l1 = nn.Linear(in_dim, hidden_dim)
        self.activation = get_activation("gelu")
        self.l2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x_data):
        if type(x_data) is dict:
            out = self.l2(self.activation(self.l1(x_data['x_b'])))
            return {'x_a': x_data['x_a'], 'x_b': out}
        else:
            x, mask_info = x_data
            out = self.l2(self.activation(self.l1(x)))
            return (out, mask_info)


class CrossModalAttention(nn.Module):
    """ Cross-Modal Attention Layer

    Given two modalities (a, b), computes K, V from modality b and Q from
    modality a.
    """

    def __init__(self, in_dim, dim, heads=8, in_dim2=None):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5

        if in_dim2 is not None:
            self.to_kv = nn.Linear(in_dim2, in_dim2 * 2, bias=False)
        else:
            self.to_kv = nn.Linear(in_dim, dim * 2, bias=False)
        self.to_q = nn.Linear(in_dim, dim, bias=False)

        if in_dim2 is not None:
            dim2 = int((in_dim + in_dim2 * 2) / 3)
        else:
            dim2 = dim
        self.to_out = nn.Linear(dim2, dim)

        self.rearrange_qkv = Rearrange(
            "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")

    def forward(self, x_data):
        x_a = x_data['x_a']
        x_b = x_data['x_b']

        kv = self.to_kv(x_b)
        q = self.to_q(x_a)

        qkv = torch.cat((q, kv), dim=-1)
        qkv = self.rearrange_qkv(qkv)
        q = qkv[0]
        k = qkv[1]
        v = qkv[2]

        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(dots, dim=-1)

        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return {'x_a': x_a, 'x_b': out}


class Attention(nn.Module):
    """ Attention Layer """

    def __init__(self, in_dim, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5

        self.to_qkv = nn.Linear(in_dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

        self.rearrange_qkv = Rearrange(
            "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")

    def forward(self, x_data):
        x, mask_info = x_data
        max_mask = mask_info['max_mask']
        mask = mask_info['mask']

        qkv = self.to_qkv(x)
        qkv = self.rearrange_qkv(qkv)
        q = qkv[0]
        k = qkv[1]
        v = qkv[2]

        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if max_mask is not None:
            # zero entries in the mask are blocked out before the softmax
            dots[:, :, :max_mask, :max_mask] = \
                dots[:, :, :max_mask, :max_mask].masked_fill(
                    mask == 0., float('-inf'))

        attn = F.softmax(dots, dim=-1)

        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return (out, mask_info)
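
# Note on calling conventions (summarising the branches above): the
# self-attention stack passes tuples of (x, mask_info), where mask_info is a
# dict with keys 'max_mask' and 'mask'; zero entries of 'mask' are blocked
# out over the first max_mask positions before the softmax, and
# {'max_mask': None, 'mask': None} disables masking. The cross-modal stack
# instead passes dicts {'x_a': ..., 'x_b': ...}: queries come from 'x_a',
# keys/values from 'x_b', and only 'x_b' is updated by each block.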


class Transformer(nn.Module):
    """ Transformer class

    Parameters
    ----------
    cross_modal : bool
        if True, uses cross-modal attention layers; otherwise builds the
        vanilla Transformer
    in_dim2 : int
        specifies the feature size of the second modality when using
        cross_modal
    """

    def __init__(self,
                 in_size=50,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 cross_modal=False,
                 in_dim2=None):
        super().__init__()
        blocks = []
        self.cross_modal = cross_modal
        if cross_modal:
            for i in range(num_hidden_layers):
                blocks.extend([
                    Residual(Norm(CrossModalAttention(
                        in_size, hidden_size,
                        heads=num_attention_heads,
                        in_dim2=in_dim2), hidden_size)),
                    Residual(Norm(MLP(hidden_size, hidden_size,
                                      intermediate_size), hidden_size))
                ])
        else:
            for i in range(num_hidden_layers):
                blocks.extend([
                    Residual(Norm(Attention(
                        in_size, hidden_size,
                        heads=num_attention_heads), hidden_size)),
                    Residual(Norm(MLP(hidden_size, hidden_size,
                                      intermediate_size), hidden_size))
                ])
        self.net = torch.nn.Sequential(*blocks)

    def forward(self, x_data):
        if self.cross_modal:
            assert type(x_data) is dict
            x_data = self.net(x_data)
            x = x_data['x_b']
        else:
            x, mask_info = x_data
            x, _ = self.net((x, mask_info))
        return x


class LinearEmbedding(nn.Module):
    """ Linear Layer """

    def __init__(self, size, dim):
        super().__init__()
        self.net = nn.Linear(size, dim)

    def forward(self, x):
        return self.net(x)


class AudioEmbedding(nn.Module):
    """ Audio embedding layer

    Parameters
    ----------
    size : int
        the input feature size of the audio embedding
    dim : int
        the desired output feature size for the audio embedding
    quant_factor : int
        specifies the number of max-pool layers applied along the temporal
        dimension
    version : str (default is 'v6')
        specifies which version of the audio embedding to use
    """

    def __init__(self, size, dim, quant_factor, version='v6'):
        super().__init__()
        self.proj = None
        if version == 'v6':
            print('MODEL V6')
            self.net = nn.MaxPool1d(4)
            layers = [nn.Sequential(nn.MaxPool1d(2))]
            for _ in range(1, quant_factor):
                layers += [nn.Sequential(nn.MaxPool1d(2))]
            self.squasher = nn.Sequential(*layers)
            self.proj = nn.Linear(size, dim)

    def forward(self, x):
        x = self.net(x)
        x = self.squasher(x)
        if self.proj is not None:
            x = self.proj(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x


class PositionEmbedding(nn.Module):
    """ Position Embedding Layer """

    def __init__(self, seq_length, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(seq_length, dim))

    def forward(self, x):
        return x + self.pos_embedding


class PositionalEncoding(nn.Module):
    """ Sinusoidal positional encoding (Vaswani et al., 2017) """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
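
# Note: PositionEmbedding adds a learned (seq_length, dim) table to
# batch-first input of shape (batch, seq_length, dim), while
# PositionalEncoding follows the (seq_len, batch, d_model) layout of the
# original sinusoidal formulation. AudioEmbedding expects (batch, size, time)
# input and downsamples the temporal axis by 4 * 2**quant_factor before
# projecting the feature axis from size to dim.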


class CrossModalLayer(nn.Module):
    """ Cross-Modal Layer inspired by FACT [Li et al. 2021] """

    def __init__(self, config):
        super().__init__()
        self.config = config
        model_config = self.config['transformer']
        self.transformer_layer = Transformer(
            in_size=model_config['hidden_size'],
            hidden_size=model_config['hidden_size'],
            num_hidden_layers=model_config['num_hidden_layers'],
            num_attention_heads=model_config['num_attention_heads'],
            intermediate_size=model_config['intermediate_size'])
        output_layer_config = self.config['output_layer']
        self.cross_norm_layer = nn.LayerNorm(self.config['in_dim'])
        self.cross_output_layer = nn.Linear(
            self.config['in_dim'],
            output_layer_config['out_dim'],
            bias=False)
        self.cross_pos_embedding = PositionEmbedding(
            self.config["sequence_length"], self.config['in_dim'])

    def forward(self, modal_a_sequences, modal_b_sequences, mask_info):
        """
        Parameters
        ----------
        modal_a_sequences : tensor
            the first modality (e.g. listener motion embedding)
        modal_b_sequences : tensor
            the second modality (e.g. speaker motion+audio embedding)
        mask_info : dict
            specifies the binary mask that is applied to the Transformer
            attention
        """
        _, _, modal_a_width = get_shape_list(modal_a_sequences)
        merged_sequences = modal_a_sequences
        if modal_b_sequences is not None:
            _, _, modal_b_width = get_shape_list(modal_b_sequences)
            if modal_a_width != modal_b_width:
                raise ValueError(
                    "The modal_a hidden size (%d) should be the same as the "
                    "modal_b hidden size (%d)" % (modal_a_width, modal_b_width))
            # concatenate the two modalities along the temporal dimension
            merged_sequences = torch.cat(
                [merged_sequences, modal_b_sequences], dim=1)
        merged_sequences = self.cross_pos_embedding(merged_sequences)
        merged_sequences = self.transformer_layer((merged_sequences, mask_info))
        merged_sequences = self.cross_norm_layer(merged_sequences)
        logits = self.cross_output_layer(merged_sequences)
        return logits
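

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original training code.
    # The shapes and hyper-parameters below are illustrative assumptions.
    batch, seq_len, feat = 2, 16, 128

    # Vanilla self-attention Transformer: input is a (x, mask_info) tuple.
    # Passing None for 'max_mask'/'mask' skips the attention mask entirely.
    vanilla = Transformer(in_size=feat, hidden_size=feat,
                          num_hidden_layers=2, num_attention_heads=8,
                          intermediate_size=256)
    x = torch.randn(batch, seq_len, feat)
    out = vanilla((x, {'max_mask': None, 'mask': None}))
    print('vanilla out:', out.shape)  # (batch, seq_len, feat)

    # Cross-modal Transformer: input is a dict; queries come from 'x_a',
    # keys/values from 'x_b'. Choosing in_dim2 == hidden_size == in_size keeps
    # the q/k/v split inside CrossModalAttention consistent.
    cross = Transformer(in_size=feat, hidden_size=feat,
                        num_hidden_layers=2, num_attention_heads=8,
                        intermediate_size=256,
                        cross_modal=True, in_dim2=feat)
    x_a = torch.randn(batch, seq_len, feat)
    x_b = torch.randn(batch, seq_len, feat)
    out = cross({'x_a': x_a, 'x_b': x_b})
    print('cross-modal out:', out.shape)  # (batch, seq_len, feat)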