|
import os |
|
import sys |
|
|
|
sys.path.append(os.getcwd()) |
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
def get_log(x):
    '''
    Return the integer base-2 logarithm of ``x``.

    x: a positive power of two (1, 2, 4, 8, ...).
    returns: ``k`` such that ``2 ** k == x``.
    raises: ValueError if ``x`` is not a positive power of two.
    '''
    # Zero and negative values used to skip the loop and come back as 0,
    # i.e. they were silently reported as 2 ** 0. Reject them explicitly.
    if x < 1:
        raise ValueError('x is not a power of 2')

    log = 0
    while x > 1:
        if x % 2 != 0:
            raise ValueError('x is not a power of 2')
        x //= 2
        log += 1
    return log
|
|
|
|
|
class ConvNormRelu(nn.Module):
    '''
    conv -> dropout -> norm (-> + shortcut) -> ReLU / LeakyReLU

    type='1d' works on (B, C_in, T) and uses Conv1d / BatchNorm1d / Dropout,
    type='2d' works on (B, C_in, H, W) and uses Conv2d / BatchNorm2d / Dropout2d.
    With the default kernel/stride/padding the spatial size is kept
    (downsample=False: k=3, s=1, pad=1) or floor-halved (downsample=True:
    k=4, s=2, pad=1). Other kernel sizes may not give exactly size / stride.

    #TODO: there might be some problems with residual: the shortcut is built
    from the channel counts *before* they are multiplied by ``groups``, and the
    identity shortcut does not follow a stride != 1.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 type='1d',
                 leaky=False,
                 downsample=False,
                 kernel_size=None,
                 stride=None,
                 padding=None,
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        '''
        in_channels / out_channels: channels per group; the main conv uses
            in_channels * groups and out_channels * groups.
        type: '1d' or '2d'; anything else raises ValueError.
        leaky: use LeakyReLU(0.2) instead of ReLU.
        downsample: selects the default kernel/stride (4/2 instead of 3/1) and
            forces a convolutional shortcut when residual is set.
        kernel_size / stride: int or tuple; a missing one falls back to the
            default chosen by ``downsample``.
        padding: defaults to int((kernel - stride) / 2) per axis.
        p: dropout probability.
        residual: add a shortcut branch (identity when shapes allow it,
            otherwise a conv with the same kernel/stride/padding).
        norm: 'gn' -> GroupNorm(2, C), 'ln' -> LayerNorm(C), any other value
            keeps the batch norm matching ``type``.
        '''
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm

        # Fail here instead of leaving self.conv undefined until forward().
        if type == '1d':
            conv_cls, bn_cls, dropout_cls = nn.Conv1d, nn.BatchNorm1d, nn.Dropout
        elif type == '2d':
            conv_cls, bn_cls, dropout_cls = nn.Conv2d, nn.BatchNorm2d, nn.Dropout2d
        else:
            raise ValueError("type must be '1d' or '2d', got %r" % (type,))

        # Each missing value gets its default independently, so passing only
        # one of kernel_size / stride no longer ends in None arithmetic.
        if kernel_size is None:
            kernel_size = 4 if downsample else 3
        if stride is None:
            stride = 2 if downsample else 1
        if padding is None:
            padding = self._default_padding(kernel_size, stride)

        if self.residual:
            if not downsample and in_channels == out_channels:
                self.residual_layer = nn.Identity()
            else:
                # Kept inside nn.Sequential so parameter names
                # (residual_layer.0.*) stay compatible with old checkpoints.
                self.residual_layer = nn.Sequential(
                    conv_cls(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding
                    )
                )

        in_channels = in_channels * groups
        out_channels = out_channels * groups
        self.conv = conv_cls(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=stride, padding=padding,
                             groups=groups)
        if norm == 'gn':
            self.norm = nn.GroupNorm(2, out_channels)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        else:
            self.norm = bn_cls(out_channels)
        self.dropout = dropout_cls(p=p)
        self.relu = nn.LeakyReLU(negative_slope=0.2) if leaky else nn.ReLU()

    @staticmethod
    def _default_padding(kernel_size, stride):
        '''int((k - s) / 2) per axis; a tuple as soon as either argument is a tuple.'''
        kernel_is_tuple = isinstance(kernel_size, tuple)
        stride_is_tuple = isinstance(stride, tuple)
        if kernel_is_tuple and stride_is_tuple:
            return tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
        if kernel_is_tuple:
            return tuple(int((ks - stride) / 2) for ks in kernel_size)
        if stride_is_tuple:
            return tuple(int((kernel_size - st) / 2) for st in stride)
        return int((kernel_size - stride) / 2)

    def forward(self, x, **kwargs):
        '''
        x: (B, C, T) for '1d' or (B, C, H, W) for '2d'. Extra keyword
        arguments are accepted and ignored.
        '''
        out = self.dropout(self.conv(x))
        if self.norm_type == 'ln':
            # LayerNorm normalises the last axis, so swap axes 1 and 2 around
            # the call; this puts channels last for 3-D (1d) input.
            out = self.norm(out.transpose(1, 2)).transpose(1, 2)
        else:
            out = self.norm(out)
        if self.residual:
            out = out + self.residual_layer(x)
        return self.relu(out)
|
|
|
|
|
class UNet1D(nn.Module):
    '''
    1D U-Net made of ConvNormRelu blocks with a constant channel width.

    Two stride-preserving blocks, then ``max_depth`` down-sampling blocks, then
    ``max_depth`` rounds of (nearest x2 upsample + skip connection + block).
    '''

    def __init__(self,
                 input_channels,
                 output_channels,
                 max_depth=5,
                 kernel_size=None,
                 stride=None,
                 p=0,
                 groups=1):
        super().__init__()

        def make_block(in_ch, downsample):
            # Every block is a leaky 1d ConvNormRelu producing output_channels.
            return ConvNormRelu(in_ch, output_channels,
                                type='1d', leaky=True, downsample=downsample,
                                kernel_size=kernel_size, stride=stride, p=p, groups=groups)

        self.max_depth = max_depth
        self.groups = groups

        self.pre_downsampling_conv = nn.ModuleList([
            make_block(input_channels, False),
            make_block(output_channels, False),
        ])
        self.conv1 = nn.ModuleList(
            [make_block(output_channels, True) for _ in range(self.max_depth)])
        self.conv2 = nn.ModuleList(
            [make_block(output_channels, False) for _ in range(self.max_depth)])
        self.upconv = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        '''
        x: tensor whose last dimension is a power of two, at least
        2 ** max_depth long.
        '''
        assert get_log(
            x.shape[-1]) >= self.max_depth, 'num_frames must be a power of 2 and its power must be greater than max_depth'

        for block in self.pre_downsampling_conv:
            x = block(x)

        # Feature maps kept for the skip connections: the pre-conv output plus
        # the output of every down block except the deepest one.
        skips = [x]
        deepest = self.max_depth - 1
        for depth, down in enumerate(self.conv1):
            x = down(x)
            if depth < deepest:
                skips.append(x)

        # Walk back up, pairing each up block with the matching resolution.
        for skip, up in zip(reversed(skips), self.conv2):
            x = up(self.upconv(x) + skip)

        return x
|
|
|
|
|
class UNet2D(nn.Module):
    '''
    Placeholder for a 2D U-Net; constructing it always raises NotImplementedError.
    '''

    def __init__(self):
        super().__init__()
        raise NotImplementedError('2D Unet is wierd')
|
|
|
|
|
class AudioPoseEncoder1D(nn.Module):
    '''
    (B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)

    The channel count is doubled until it reaches or exceeds C_out, then
    projected to exactly C_out, and padded with C_out -> C_out blocks until at
    least ``min_layer_nums`` layers exist. T is only preserved with
    ConvNormRelu's default kernel/stride.
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layer_nums=None
                 ):
        '''
        C_in: input channels, must be >= 1.
        C_out: output channels.
        kernel_size / stride: forwarded to every ConvNormRelu.
        min_layer_nums: minimum number of layers, or None for no minimum.
        raises: ValueError if C_in < 1 (doubling would never reach C_out).
        '''
        super(AudioPoseEncoder1D, self).__init__()
        if C_in < 1:
            raise ValueError('C_in must be >= 1, got %r' % (C_in,))
        self.C_in = C_in
        self.C_out = C_out

        # Treat "no minimum" as 0 so the loop below never compares an int with
        # None (that raised TypeError whenever doubling overshot C_out).
        min_layers = 0 if min_layer_nums is None else min_layer_nums

        conv_layers = []
        cur_C = C_in
        while cur_C < self.C_out:
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=cur_C * 2,
                kernel_size=kernel_size,
                stride=stride
            ))
            cur_C *= 2

        # First pass fixes an overshoot (cur_C -> C_out); later passes only pad
        # the depth with C_out -> C_out blocks.
        while cur_C != C_out or len(conv_layers) < min_layers:
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=C_out,
                kernel_size=kernel_size,
                stride=stride
            ))
            cur_C = C_out

        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        '''
        x: (B, C, T)
        '''
        return self.conv_layers(x)
|
|
|
|
|
class AudioPoseEncoder2D(nn.Module):
    '''
    (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)

    Not implemented: the constructor raises NotImplementedError.
    '''

    def __init__(self):
        raise NotImplementedError()
|
|
|
|
|
class AudioPoseEncoderRNN(nn.Module):
    '''
    (B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T)

    C_out is hidden_size, or 2 * hidden_size when bidirectional.
    '''

    def __init__(self,
                 C_in,
                 hidden_size,
                 num_layers,
                 rnn_cell='gru',
                 bidirectional=False
                 ):
        '''
        rnn_cell: 'gru' or 'lstm'; anything else raises ValueError.
        '''
        super().__init__()
        if rnn_cell == 'gru':
            rnn_cls = nn.GRU
        elif rnn_cell == 'lstm':
            rnn_cls = nn.LSTM
        else:
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))

        self.cell = rnn_cls(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers,
                            batch_first=True, bidirectional=bidirectional)

    def forward(self, x, state=None):
        '''
        x: (B, C, T); state: optional initial RNN state.
        Only the output sequence is returned; the final state is discarded.
        '''
        time_major = x.permute(0, 2, 1)
        encoded, _ = self.cell(time_major, state)
        return encoded.permute(0, 2, 1)
|
|
|
|
|
class AudioPoseEncoderGraph(nn.Module):
    '''
    (B, C, T)->(B, V, C_in, T)->(B, C_in, T, V)->graph conv stack

    V is A.shape[1] and C_in is the input width of the first layer config.
    '''

    def __init__(self,
                 layers_config,
                 A,
                 residual,
                 local_bn=False,
                 share_weights=False
                 ) -> None:
        '''
        layers_config: sequence of (c_in, c_out, kernel_size), one per layer.
        A: array/tensor whose shape[0] is stored as num_parts and shape[1]
            as num_joints; passed unchanged to every GraphConvNormRelu.
        residual / local_bn / share_weights: forwarded to every layer.
        '''
        super().__init__()
        self.A = A
        self.num_joints = A.shape[1]
        self.num_parts = A.shape[0]
        self.C_in = layers_config[0][0]
        self.C_out = layers_config[-1][1]

        graph_layers = []
        for c_in, c_out, k in layers_config:
            graph_layers.append(GraphConvNormRelu(
                C_in=c_in,
                C_out=c_out,
                A=self.A,
                residual=residual,
                local_bn=local_bn,
                kernel_size=k,
                share_weights=share_weights
            ))
        self.conv_layers = nn.Sequential(*graph_layers)

    def forward(self, x):
        '''
        x: (B, C, T), C should be num_joints*C_in
        output: whatever the graph conv stack returns for a (B, C_in, T, V)
        input (presumably (B, D, T, V) -- confirm against GraphConvNormRelu)
        '''
        B, C, T = x.shape
        # Split the flat channel axis into (joint, feature), then move the
        # features to axis 1 and the joints to the last axis.
        per_joint = x.view(B, self.num_joints, self.C_in, T).permute(0, 2, 3, 1)
        assert per_joint.shape[1] == self.C_in

        return self.conv_layers(per_joint)
|
|
|
|
|
class SeqEncoder2D(nn.Module):
    '''
    seq_encoder, encoding a seq to a vector
    (B, C, T)->(B, V, C_in, T)->(B, C_in, T, V) -> (B, 32, T, V)->...->(B, C_out)

    After the first 32-channel block, every further block doubles the channels
    (capped at C_out) and shrinks each spatial axis that is still larger
    than 1, until the feature map is C_out x 1 x 1.
    '''

    def __init__(self,
                 C_in,
                 T_in,
                 C_out,
                 num_joints,
                 min_layer_num=None,
                 residual=False
                 ):
        '''
        C_in: features per joint.
        T_in: number of frames.
        C_out: size of the output vector.
        num_joints: number of joints V.
        min_layer_num: if set, 1x1 C_out -> C_out blocks are appended until
            this many layers exist.
        residual: forwarded to every ConvNormRelu.
        '''
        super(SeqEncoder2D, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.T_in = T_in
        self.num_joints = num_joints

        conv_layers = [ConvNormRelu(
            in_channels=C_in,
            out_channels=32,
            type='2d',
            residual=residual
        )]

        cur_C = 32
        cur_H = T_in
        cur_W = num_joints
        num_layers = 1
        while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
            k_h, s_h, cur_H = self._axis_plan(cur_H)
            k_w, s_w, cur_W = self._axis_plan(cur_W)
            next_C = min(C_out, cur_C * 2)

            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=next_C,
                type='2d',
                kernel_size=(k_h, k_w),
                stride=(s_h, s_w),
                residual=residual
            ))
            cur_C = next_C
            num_layers += 1

        while min_layer_num is not None and num_layers < min_layer_num:
            conv_layers.append(ConvNormRelu(
                in_channels=C_out,
                out_channels=C_out,
                type='2d',
                kernel_size=1,
                stride=1,
                residual=residual
            ))
            num_layers += 1

        self.conv_layers = nn.Sequential(*conv_layers)
        self.num_layers = num_layers

    @staticmethod
    def _axis_plan(size):
        '''Return (kernel, stride, resulting size) for one spatial axis.'''
        if size > 4:
            return 4, 2, size // 2      # halve
        if size > 1:
            return size, size, 1        # collapse the rest in one step
        return 3, 1, size               # already 1: size-preserving conv

    def forward(self, x):
        '''
        x: (B, C, T) with C == num_joints * C_in
        output: (B, C_out)
        '''
        B, C, T = x.shape
        x = x.view(B, self.num_joints, self.C_in, T)
        x = x.permute(0, 2, 3, 1)
        assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints

        x = self.conv_layers(x)
        # Only drop the two spatial axes. A bare squeeze() would also remove
        # the batch axis for B == 1 (and the channel axis for C_out == 1).
        return x.squeeze(-1).squeeze(-1)
|
|
|
|
|
class SeqEncoder1D(nn.Module):
    '''
    (B, C, T)->(B, D)

    Blocks are stacked until the channel count reaches C_out and the temporal
    length reaches 1; each block doubles the channels (capped at C_out) and
    halves T while T > 4, or collapses the remaining T <= 4 frames at once.
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 T_in,
                 min_layer_nums=None
                 ):
        '''
        C_in: input channels.
        C_out: output vector size D.
        T_in: input sequence length.
        min_layer_nums: if set, 1x1 C_out -> C_out blocks are appended until
            this many layers exist.
        '''
        super(SeqEncoder1D, self).__init__()
        conv_layers = []
        cur_C = C_in
        cur_T = T_in
        self.num_layers = 0
        while (cur_C < C_out) or (cur_T > 1):
            if cur_T > 4:
                ks, st, next_T = 4, 2, cur_T // 2
            elif cur_T > 1:
                ks, st, next_T = cur_T, cur_T, 1
            else:
                ks, st, next_T = 3, 1, cur_T
            next_C = min(C_out, cur_C * 2)

            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=next_C,
                type='1d',
                kernel_size=ks,
                stride=st
            ))
            cur_C = next_C
            cur_T = next_T
            self.num_layers += 1

        while min_layer_nums is not None and self.num_layers < min_layer_nums:
            conv_layers.append(ConvNormRelu(
                in_channels=C_out,
                out_channels=C_out,
                type='1d',
                kernel_size=1,
                stride=1
            ))
            self.num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        '''
        x: (B, C, T)
        output: (B, C_out)
        '''
        x = self.conv_layers(x)
        # Only drop the temporal axis. A bare squeeze() would also remove the
        # batch axis for B == 1 (and the channel axis for C_out == 1).
        return x.squeeze(-1)
|
|
|
|
|
class SeqEncoderRNN(nn.Module):
    '''
    (B, C, T) -> (B, T, C) -> (B, D)
    LSTM/GRU summary of a sequence

    D is hidden_size, or 2 * hidden_size when bidirectional.
    '''

    def __init__(self,
                 hidden_size,
                 in_size,
                 num_rnn_layers,
                 rnn_cell='gru',
                 bidirectional=False
                 ):
        '''
        hidden_size: RNN hidden width.
        in_size: input channels C.
        num_rnn_layers: number of stacked RNN layers.
        rnn_cell: 'gru' or 'lstm'.
        bidirectional: run the RNN in both directions.
        raises: ValueError for any other rnn_cell (previously the module was
            built without ``self.cell`` and failed only in forward()).
        '''
        super(SeqEncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.in_size = in_size
        self.num_rnn_layers = num_rnn_layers
        self.bidirectional = bidirectional

        if rnn_cell == 'gru':
            rnn_cls = nn.GRU
        elif rnn_cell == 'lstm':
            rnn_cls = nn.LSTM
        else:
            # Same message as the other RNN wrappers in this file.
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))

        self.cell = rnn_cls(input_size=self.in_size, hidden_size=self.hidden_size,
                            num_layers=self.num_rnn_layers,
                            batch_first=True, bidirectional=bidirectional)

    def forward(self, x, state=None):
        '''
        x: (B, C, T); state: optional initial RNN state.
        output: (B, D)
        '''
        x = x.permute(0, 2, 1)
        B, T, C = x.shape
        x, _ = self.cell(x, state)
        if self.bidirectional:
            # The forward direction has seen the whole sequence at the last
            # step, the backward direction at the first step.
            out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
        else:
            out = x[:, -1, :]
        assert out.shape[0] == B
        return out
|
|
|
|
|
class SeqEncoderGraph(nn.Module):
    '''
    Graph encoder followed by spatial and temporal pooling convolutions:
    (B, C, T) -> graph_encoder -> spatial_pool -> temporal_pool -> (B, embedding_size)
    '''

    def __init__(self,
                 embedding_size,
                 layer_configs,
                 residual,
                 local_bn,
                 A,
                 T,
                 share_weights=False
                 ) -> None:
        '''
        embedding_size: size of the output vector (stored as C_out).
        layer_configs: (c_in, c_out, kernel_size) per graph layer.
        residual / local_bn / share_weights: forwarded to AudioPoseEncoderGraph.
        A: adjacency-like array; A.shape[1] is the number of joints.
        T: number of input frames, used to plan the temporal pooling stack.
        '''
        super().__init__()

        self.C_in = layer_configs[0][0]
        self.C_out = embedding_size
        self.num_joints = A.shape[1]

        self.graph_encoder = AudioPoseEncoderGraph(
            layers_config=layer_configs,
            A=A,
            residual=residual,
            local_bn=local_bn,
            share_weights=share_weights
        )

        # One conv spanning all joints (kernel width = num_joints, no padding).
        channels = layer_configs[-1][1]
        self.spatial_pool = ConvNormRelu(
            in_channels=channels,
            out_channels=channels,
            type='2d',
            kernel_size=(1, self.num_joints),
            stride=(1, 1),
            padding=(0, 0)
        )

        # Temporal stack: double the channels (capped at C_out) and shrink the
        # frame axis until both targets are met.
        frames = T
        stages = []
        self.temporal_conv_info = []
        while channels < self.C_out or frames > 1:
            self.temporal_conv_info.append(channels)

            if frames > 4:
                k_t, s_t, frames_after = 4, 2, frames // 2
            elif frames > 1:
                k_t, s_t, frames_after = frames, frames, 1
            else:
                k_t, s_t, frames_after = 3, 1, frames

            widened = min(self.C_out, channels * 2)
            stages.append(ConvNormRelu(
                in_channels=channels,
                out_channels=widened,
                type='2d',
                kernel_size=(k_t, 1),
                stride=(s_t, 1)
            ))
            channels = widened
            frames = frames_after

        self.temporal_pool = nn.Sequential(*stages)
        print("graph seq encoder info: temporal pool:", self.temporal_conv_info)
        self.num_layers = len(stages)

    def forward(self, x):
        '''
        x: (B, C, T)
        output: (B, C_out)
        '''
        B, C, T = x.shape
        pooled = self.temporal_pool(self.spatial_pool(self.graph_encoder(x)))
        return pooled.view(B, self.C_out)
|
|
|
|
|
class SeqDecoder2D(nn.Module):
    '''
    (B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T)

    Not implemented: the constructor raises NotImplementedError.
    '''

    def __init__(self):
        super().__init__()
        raise NotImplementedError()
|
|
|
|
|
class SeqDecoder1D(nn.Module):
    '''
    (B, D)->(B, D, 1)->...->(B, C_out, T)

    pre_conv, then one (x2 upsample + conv) step per power of two <= T_out,
    then a resize to exactly T_out and the post_conv stack.
    '''

    def __init__(self,
                 D_in,
                 C_out,
                 T_out,
                 min_layer_num=None
                 ):
        '''
        D_in: size of the input vector.
        C_out: output channels (also the width of every hidden block).
        T_out: output sequence length.
        min_layer_num: if set, extra post-conv blocks are added until this
            many layers exist.
        '''
        super().__init__()
        self.T_out = T_out
        self.min_layer_num = min_layer_num

        def same_width_block():
            return ConvNormRelu(
                in_channels=C_out,
                out_channels=C_out,
                type='1d'
            )

        self.pre_conv = ConvNormRelu(
            in_channels=D_in,
            out_channels=C_out,
            type='1d'
        )
        self.num_layers = 1
        self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_layers = nn.ModuleList([])

        # The sequence starts at length 1; add one doubling step for every
        # length 2, 4, 8, ... that does not exceed T_out.
        length = 2
        while length <= T_out:
            self.conv_layers.append(same_width_block())
            length *= 2
            self.num_layers += 1

        tail = [same_width_block()]
        self.num_layers += 1
        while min_layer_num is not None and self.num_layers < min_layer_num:
            tail.append(same_width_block())
            self.num_layers += 1
        self.post_conv = nn.Sequential(*tail)

    def forward(self, x):
        '''
        x: (B, D)
        output: (B, C_out, T_out)
        '''
        x = self.pre_conv(x.unsqueeze(-1))
        for block in self.conv_layers:
            x = block(self.upconv(x))

        # The doubling steps end on a power of two; resize to the exact length.
        x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
        return self.post_conv(x)
|
|
|
|
|
class SeqDecoderRNN(nn.Module):
    '''
    (B, D)->(B, C_out, T)

    Autoregressive decoder: each generated frame is fed back as the next input.
    '''

    def __init__(self,
                 hidden_size,
                 C_out,
                 T_out,
                 num_layers,
                 rnn_cell='gru'
                 ):
        '''
        hidden_size: RNN hidden width.
        C_out: channels of each generated frame (also the RNN input size).
        T_out: number of frames to generate.
        num_layers: number of stacked RNN layers.
        rnn_cell: 'gru' or 'lstm'; anything else raises ValueError.
        '''
        super().__init__()
        self.num_steps = T_out
        if rnn_cell == 'gru':
            rnn_cls = nn.GRU
        elif rnn_cell == 'lstm':
            rnn_cls = nn.LSTM
        else:
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))

        self.cell = rnn_cls(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers,
                            batch_first=True, bidirectional=False)
        self.fc = nn.Linear(hidden_size, C_out)

    def forward(self, hidden, frame_0):
        '''
        hidden: initial RNN state.
        frame_0: first decoder input, channels on axis 1 and time on axis 2.
        output: generated frames concatenated along the last axis.
        '''
        step_input = frame_0.permute(0, 2, 1)
        generated = []
        for _ in range(self.num_steps):
            rnn_out, hidden = self.cell(step_input, hidden)
            step_input = self.fc(rnn_out)
            generated.append(step_input)
        return torch.cat(generated, dim=1).permute(0, 2, 1)
|
|
|
|
|
class SeqTranslator2D(nn.Module):
    '''
    (B, C, T)->(B, 1, C, T)-> ... -> conv stack -> squeeze(2)

    The layer stack is hard coded (channel widths 32/64/128/108 and fixed
    strides); C_out and T_out are stored but do not shape the network.
    '''

    def __init__(self,
                 C_in=64,
                 C_out=108,
                 T_in=75,
                 T_out=25,
                 residual=True
                 ):
        super().__init__()
        print("Warning: hard coded")
        self.C_in = C_in
        self.C_out = C_out
        self.T_in = T_in
        self.T_out = T_out
        self.residual = residual

        # (in_channels, out_channels, kernel_size, stride, uses the residual flag)
        layer_specs = [
            (1, 32, 5, 1, False),
            (32, 32, 5, 1, True),
            (32, 32, 5, 1, True),

            (32, 64, 5, (4, 3), False),
            (64, 64, 5, 1, True),
            (64, 64, 5, 1, True),

            (64, 128, 5, (4, 1), False),
            (128, 108, 3, (4, 1), False),
            (108, 108, (1, 3), 1, True),

            (108, 108, (1, 3), 1, True),
            (108, 108, (1, 3), 1, False),
        ]
        blocks = []
        for c_in, c_out, kernel, step, with_shortcut in layer_specs:
            blocks.append(ConvNormRelu(
                c_in, c_out, '2d', kernel_size=kernel, stride=step,
                residual=self.residual if with_shortcut else False))
        self.conv_layers = nn.Sequential(*blocks)

    def forward(self, x):
        '''
        x: (B, C_in, T_in); it is treated as a one-channel image.
        '''
        assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
        image = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
        return self.conv_layers(image).squeeze(2)
|
|
|
|
|
class SeqTranslator1D(nn.Module):
    '''
    (B, C, T)->(B, C_out, T)

    One C_in -> C_out block followed by C_out -> C_out blocks until
    ``min_layers_num`` layers exist (T is kept with the default kernel/stride).
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layers_num=None,
                 residual=True,
                 norm='bn'
                 ):
        '''
        kernel_size / stride / residual / norm: forwarded to every ConvNormRelu.
        min_layers_num: minimum number of layers, or None for a single layer.
        '''
        super().__init__()

        def make_block(in_ch):
            return ConvNormRelu(
                in_channels=in_ch,
                out_channels=C_out,
                type='1d',
                kernel_size=kernel_size,
                stride=stride,
                residual=residual,
                norm=norm
            )

        blocks = [make_block(C_in)]
        self.num_layers = 1
        if min_layers_num is not None:
            while self.num_layers < min_layers_num:
                blocks.append(make_block(C_out))
                self.num_layers += 1
        self.conv_layers = nn.Sequential(*blocks)

    def forward(self, x):
        '''
        x: (B, C, T)
        '''
        return self.conv_layers(x)
|
|
|
|
|
class SeqTranslatorRNN(nn.Module):
    '''
    (B, C, T)->(B, C_out, T)
    RNN encoder -> autoregressive RNN decoder -> FC
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 hidden_size,
                 num_layers,
                 rnn_cell='gru'
                 ):
        '''
        rnn_cell: 'gru' or 'lstm' (used for both encoder and decoder);
        anything else raises ValueError.
        '''
        super().__init__()

        if rnn_cell == 'gru':
            rnn_cls = nn.GRU
        elif rnn_cell == 'lstm':
            rnn_cls = nn.LSTM
        else:
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))

        shared = dict(hidden_size=hidden_size, num_layers=num_layers,
                      batch_first=True, bidirectional=False)
        self.enc_cell = rnn_cls(input_size=C_in, **shared)
        self.dec_cell = rnn_cls(input_size=C_out, **shared)
        self.fc = nn.Linear(hidden_size, C_out)

    def forward(self, x, frame_0):
        '''
        x: (B, C, T) source sequence; its final encoder state seeds the decoder.
        frame_0: first decoder input, channels on axis 1 and time on axis 2.
        output: one generated frame per source step, concatenated on the last axis.
        '''
        num_steps = x.shape[-1]
        _, hidden = self.enc_cell(x.permute(0, 2, 1), None)

        step_input = frame_0.permute(0, 2, 1)
        generated = []
        for _ in range(num_steps):
            rnn_out, hidden = self.dec_cell(step_input, hidden)
            step_input = self.fc(rnn_out)
            generated.append(step_input)
        return torch.cat(generated, dim=1).permute(0, 2, 1)
|
|
|
|
|
class ResBlock(nn.Module):
    '''
    Fully connected residual block: a three-layer ReLU MLP
    (input_dim -> fc_dim/2 -> fc_dim/2 -> fc_dim) summed with a one-layer
    ReLU shortcut (input_dim -> fc_dim).
    '''

    def __init__(self,
                 input_dim,
                 fc_dim,
                 afn,
                 nfn
                 ):
        '''
        afn: activation fn, only 'relu' is accepted
        nfn: normalization fn, 'layer_norm' is rejected (no normalization
             layer is built for any value)
        '''
        super().__init__()

        self.input_dim = input_dim
        self.fc_dim = fc_dim
        self.afn = afn
        self.nfn = nfn

        if self.afn != 'relu':
            raise ValueError('Wrong')
        if self.nfn == 'layer_norm':
            raise ValueError('wrong')

        half_dim = self.fc_dim // 2
        main_path = [
            nn.Linear(self.input_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, self.fc_dim),
            nn.ReLU(),
        ]
        self.layers = nn.Sequential(*main_path)

        shortcut = [nn.Linear(self.input_dim, self.fc_dim), nn.ReLU()]
        self.shortcut_layer = nn.Sequential(*shortcut)

    def forward(self, inputs):
        '''
        inputs: (..., input_dim) -> (..., fc_dim)
        '''
        main = self.layers(inputs)
        skip = self.shortcut_layer(inputs)
        return main + skip
|
|
|
|
|
class AudioEncoder(nn.Module):
    '''
    Stack of (reflection pad -> Conv1d -> LeakyReLU(0.2) [-> pool]) stages,
    one per consecutive pair in ``channels``.
    '''

    def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
        '''
        channels: channel widths; channels[0] is the number of input channels
            actually used.
        padding: reflection padding applied before every stage conv.
        kernel_size / conv_stride: shared by every Conv1d.
        conv_pool: optional pooling class, called as
            conv_pool(kernel_size=2, stride=2) after every stage.
        augmentation: append one more unpadded conv + activation at the end.
        '''
        super().__init__()
        self.in_channels = channels[0]
        self.augmentation = augmentation

        # A single activation instance is shared by every stage.
        activation = nn.LeakyReLU(0.2)

        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.extend([
                nn.ReflectionPad1d(padding),
                nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=conv_stride),
                activation,
            ])
            if conv_pool is not None:
                stages.append(conv_pool(kernel_size=2, stride=2))

        if self.augmentation:
            stages.extend([
                nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride),
                activation,
            ])

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        '''
        x: (B, C, T); only the first ``in_channels`` channels are used.
        '''
        return self.model(x[:, :self.in_channels, :])
|
|
|
|
|
class AudioDecoder(nn.Module):
    '''
    Stack of (x2 upsample -> reflection pad -> Conv1d [-> Dropout] -> LeakyReLU)
    stages, followed by a resize to ``ups`` frames and a final padded conv
    without activation.
    '''

    def __init__(self, channels, kernel_size=7, ups=25):
        '''
        channels: channel widths; the last pair feeds the output conv.
        kernel_size: shared by every Conv1d (padding is (kernel_size - 1) // 2).
        ups: output length set by the final nn.Upsample.
        '''
        super().__init__()

        pad = (kernel_size - 1) // 2
        # A single activation instance is shared by every stage.
        activation = nn.LeakyReLU(0.2)

        stages = []
        for idx in range(len(channels) - 2):
            stages.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.ReflectionPad1d(pad),
                nn.Conv1d(channels[idx], channels[idx + 1],
                          kernel_size=kernel_size, stride=1),
            ])
            # Dropout only in the first two stages.
            if idx < 2:
                stages.append(nn.Dropout(p=0.2))
            # idx never reaches len(channels) - 2 inside this loop, so every
            # stage ends with the activation.
            stages.append(activation)

        stages.extend([
            nn.Upsample(size=ups, mode='nearest'),
            nn.ReflectionPad1d(pad),
            nn.Conv1d(channels[-2], channels[-1],
                      kernel_size=kernel_size, stride=1),
        ])

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        '''
        x: (B, channels[0], T) -> (B, channels[-1], ups)
        '''
        return self.model(x)
|
|
|
|
|
class Audio2Pose(nn.Module):
    '''
    Audio features -> pose sequence: AudioEncoder followed by AudioDecoder.

    With ``augmentation`` a second input is embedded to 256 channels, repeated
    along time and concatenated with the audio embedding before decoding.
    '''

    def __init__(self, pose_dim, embed_size, augmentation, ups=25):
        '''
        pose_dim: channels of the decoder output.
        embed_size: the augmentation input has embed_size // 2 features.
        augmentation: enable the extra decoder input (see forward).
        ups: output length in frames; now honoured in both modes (it used to
            be ignored when augmentation was enabled, leaving the default 25).
        '''
        super(Audio2Pose, self).__init__()
        self.pose_dim = pose_dim
        self.embed_size = embed_size
        self.augmentation = augmentation

        self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
                                    conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
        if self.augmentation:
            # 256 audio channels + 256 embedded channels
            self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim], ups=ups)
            self.pose_enc = nn.Sequential(
                nn.Linear(self.embed_size // 2, 256),
                nn.LayerNorm(256)
            )
        else:
            self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)

    def forward(self, audio_feat, dec_input=None):
        '''
        audio_feat: (B, C, T); the encoder only uses the first 13 channels.
        dec_input: required when augmentation is enabled. After squeeze(0) its
            last dimension must be embed_size // 2 (presumably one vector per
            batch item -- confirm against the caller).
        raises: ValueError if augmentation is enabled and dec_input is None.
        '''
        aud_embed = self.aud_enc(audio_feat)

        if self.augmentation:
            if dec_input is None:
                raise ValueError('dec_input is required when augmentation is enabled')
            dec_embed = self.pose_enc(dec_input.squeeze(0))
            # Repeat the embedding along time so it can be stacked on the
            # audio embedding channel-wise.
            dec_embed = dec_embed.unsqueeze(2)
            dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
            aud_embed = torch.cat([aud_embed, dec_embed], dim=1)

        return self.aud_dec(aud_embed)
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a SeqEncoder2D and push one random batch through it.
    import numpy as np
    import os
    import sys

    encoder_kwargs = dict(C_in=2, T_in=25, C_out=512, num_joints=54)
    test_model = SeqEncoder2D(**encoder_kwargs)
    print(test_model.num_layers)

    # 108 channels = 54 joints * 2 features, 25 frames
    dummy_batch = torch.randn((64, 108, 25))
    print(test_model(dummy_batch).shape)