File size: 4,775 Bytes
f0c7f08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import sys

sys.path.append(os.getcwd())

from nets.base import TrainWrapperBaseClass
from nets.spg.s2glayers import Discriminator as D_S2G
from nets.spg.vqvae_1d import AE as s2g_body
import torch
import torch.optim as optim
import torch.nn.functional as F

from data_utils.lower_body import c_index, c_index_3d, c_index_6d


def separate_aa(aa):
    """Split packed 5-value groups into a unit axis part and a unit angle part.

    The trailing dimension of ``aa`` is regrouped into chunks of 5 values, so
    its size must be a multiple of 5. In each chunk, the first 3 values are
    L2-normalised as the axis. The last 2 values are L2-normalised as the
    angle part (presumably a 2-D cos/sin-style encoding; TODO confirm).

    Args:
        aa: tensor with at least 3 dimensions whose first two dimensions are
            kept as-is. It is assumed to be (batch, frames, 5 * k); TODO
            confirm with callers.

    Returns:
        ``(axis, angle)`` with shapes ``(aa.shape[0], aa.shape[1], k, 3)``
        and ``(aa.shape[0], aa.shape[1], k, 2)``.
    """
    n_batch, n_frames = aa.shape[0], aa.shape[1]
    # The three-axis slice is a no-op on valid input. It makes inputs with
    # fewer than 3 dims fail with IndexError before the reshape.
    chunks = aa[:, :, :].reshape(n_batch, n_frames, -1, 5)
    return (
        F.normalize(chunks[..., :3], dim=-1),
        F.normalize(chunks[..., 3:5], dim=-1),
    )


class TrainWrapper(TrainWrapperBaseClass):
    '''
    A wrapper that receives a batch from data_utils and computes the
    reconstruction loss of the body autoencoder ``self.g``.

    The optimisation step is performed inside the wrapper (see ``vq_train``).
    '''

    def __init__(self, args, config):
        """Build the body autoencoder (and a discriminator if ``self.gan``).

        Args:
            args: runtime arguments; ``args.gpu`` selects the torch device.
            config: experiment config; ``config.Data.pose`` supplies
                ``convert_to_6d``, ``pre_pose_length`` and ``expression``.
        """
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0

        self.gan = False
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.preleng = self.config.Data.pose.pre_pose_length
        self.expression = self.config.Data.pose.expression
        self.epoch = 0
        # NOTE(review): init_params() is not defined in this file. It is
        # presumably inherited and must set self.each_dim (used just below).
        self.init_params()
        self.num_classes = 4
        self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
                          num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
        if self.gan:
            # NOTE(review): self.pose is never set in this file. The branch is
            # unreachable while self.gan is hard-coded to False above.
            self.discriminator = D_S2G(
                pose_dim=110 + 64, pose=self.pose
            ).to(self.device)
        else:
            self.discriminator = None

        # Indices into the pose-channel axis. __call__ applies them to dim 1
        # of bat['poses'] and extract() applies them to dim 2 of its input.
        self.c_index = c_index_6d if self.convert_to_6d else c_index_3d

        # Kept last, as in the original ordering: the base initializer
        # presumably relies on the attributes set above (e.g. to call
        # init_optimizer) — TODO confirm.
        super().__init__(args, config)

    def init_optimizer(self):
        """Create the Adam optimizer for the autoencoder ``self.g``.

        No discriminator optimizer is created here.
        """
        self.g_optimizer = optim.Adam(
            self.g.parameters(),
            lr=self.config.Train.learning_rate.generator_learning_rate,
            betas=[0.9, 0.999]
        )

    def state_dict(self):
        """Return a checkpoint dict with model and optimizer states.

        The discriminator entries are None when no discriminator (or no
        discriminator optimizer) exists.
        """
        # init_optimizer never defines discriminator_optimizer, so read it
        # defensively instead of raising AttributeError.
        d_optim = getattr(self, 'discriminator_optimizer', None)
        has_disc = self.discriminator is not None
        model_state = {
            'g': self.g.state_dict(),
            'g_optim': self.g_optimizer.state_dict(),
            'discriminator': self.discriminator.state_dict() if has_disc else None,
            'discriminator_optim': d_optim.state_dict() if has_disc and d_optim is not None else None
        }
        return model_state

    def __call__(self, bat):
        """Run one training step of ``self.g`` on a batch.

        Args:
            bat: batch dict; ``bat['poses']`` is indexed by self.c_index on
                dim 1 and trimmed by ``self.preleng`` on dim 2.

        Returns:
            ``(None, loss_dict)``. The backward pass and optimizer step have
            already been performed inside ``vq_train``, so the first element
            is always None. ``loss_dict`` maps 'g'-prefixed loss names to
            Python floats.
        """
        self.global_step += 1

        poses = bat['poses'].to(self.device).to(torch.float32)
        poses = poses[:, self.c_index, :]
        # Drop the first preleng steps of dim 2, then move that axis to dim 1.
        gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)

        loss_dict, _ = self.vq_train(gt_poses, 'g', self.g, {}, 0)

        return None, loss_dict

    def vq_train(self, gt, name, model, dict, total_loss, pre=None):
        """Run one forward/backward/optimizer step for ``model``.

        Args:
            gt: ground-truth poses, passed to the model as ``gt_poses``.
            name: model tag. It selects the optimizer attribute and prefixes
                the loss keys. Only 'g' is supported.
            model: module called as ``model(gt_poses=gt, pre_state=pre)``.
            dict: mapping that receives scalar losses as ``name + key``.
                The name shadows the builtin and is kept for compatibility.
            total_loss: returned unchanged. It is not accumulated here.
            pre: optional previous state, forwarded to the model and get_loss.

        Returns:
            ``(dict, total_loss)``.

        Raises:
            ValueError: if ``name`` has no associated optimizer.
        """
        optimizer_names = {'g': 'g_optimizer'}
        # Validate before the forward pass. Previously an unknown name only
        # failed afterwards, with an UnboundLocalError.
        if name not in optimizer_names:
            raise ValueError(
                f"unknown model name {name!r}; expected one of {sorted(optimizer_names)}"
            )
        optimizer = getattr(self, optimizer_names[name])

        x_recon = model(gt_poses=gt, pre_state=pre)
        loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for key, value in loss_dict.items():
            dict[name + key] = value.item()
        return dict, total_loss

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 pre=None
                 ):
        """Compute the L1 reconstruction and velocity losses.

        Dim 1 of ``pred_poses``/``gt_poses`` is differenced as the time axis.
        When called from __call__, the inputs are presumably
        (batch, frames, channels) after the permute; TODO confirm.

        Args:
            pred_poses: model output, same shape as ``gt_poses``.
            gt_poses: ground-truth poses.
            pre: optional preceding frames. If given, an extra term compares
                the first-frame velocity relative to ``pre[:, -1]``.

        Returns:
            ``(gen_loss, loss_dict)``. ``gen_loss`` is the sum of all terms.
            ``loss_dict`` holds 'rec_loss', 'velocity_loss' and, when ``pre``
            is given, 'f0_vel' (all as tensors).
        """
        loss_dict = {}

        rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
        v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
        v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
        velocity_loss = torch.mean(torch.abs(v_pr - v_gt))

        if pre is None:
            f0_vel = 0
        else:
            # Velocity of the first frame relative to the last frame of ``pre``.
            v0_pr = pred_poses[:, 0] - pre[:, -1]
            v0_gt = gt_poses[:, 0] - pre[:, -1]
            f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))

        gen_loss = rec_loss + velocity_loss + f0_vel

        loss_dict['rec_loss'] = rec_loss
        loss_dict['velocity_loss'] = velocity_loss
        if pre is not None:
            loss_dict['f0_vel'] = f0_vel

        return gen_loss, loss_dict

    def load_state_dict(self, state_dict):
        """Restore only the autoencoder weights from ``state_dict['g']``.

        Optimizer and discriminator states are not restored.
        """
        self.g.load_state_dict(state_dict['g'])

    def extract(self, x):
        """Encode poses with ``self.g`` and return ``(features, x)``.

        Side effect: puts ``self.g`` in eval mode and does not restore the
        previous mode. No ``torch.no_grad`` is applied here.

        Args:
            x: pose tensor whose dim 2 is the channel axis. When it is wider
                than ``self.full_dim`` it is reduced to the body channels.

        Returns:
            ``(feat.transpose(1, 2), x)``, where ``x`` is the possibly-reduced
            input actually fed to the encoder.
        """
        self.g.eval()
        # NOTE(review): self.full_dim is not set in this file. It is presumably
        # provided by init_params()/the base class.
        if x.shape[2] > self.full_dim:
            if x.shape[2] == 239:
                # NOTE(review): for 239-channel inputs the first 102 channels
                # are dropped before c_index is applied. The meaning of these
                # constants is not visible here — TODO document.
                x = x[:, :, 102:]
            x = x[:, :, self.c_index]
        feat = self.g.encode(x)
        return feat.transpose(1, 2), x