shadowcun committed on
Commit
cdf3959
·
1 Parent(s): 9ab094a

new version of sadtalker

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/__pycache__/generate_batch.cpython-38.pyc +0 -0
  2. src/__pycache__/generate_facerender_batch.cpython-38.pyc +0 -0
  3. src/__pycache__/test_audio2coeff.cpython-38.pyc +0 -0
  4. src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
  5. src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
  6. src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
  7. src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
  8. src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
  9. src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
  10. src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
  11. src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
  12. src/{src/config → config}/facerender_pirender.yaml +0 -0
  13. src/face3d/models/__pycache__/__init__.cpython-38.pyc +0 -0
  14. src/face3d/models/__pycache__/base_model.cpython-38.pyc +0 -0
  15. src/face3d/models/__pycache__/networks.cpython-38.pyc +0 -0
  16. src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc +0 -0
  17. src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc +0 -0
  18. src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc +0 -0
  19. src/face3d/util/__pycache__/__init__.cpython-38.pyc +0 -0
  20. src/face3d/util/__pycache__/load_mats.cpython-38.pyc +0 -0
  21. src/face3d/util/__pycache__/preprocess.cpython-38.pyc +0 -0
  22. src/facerender/__pycache__/animate.cpython-38.pyc +0 -0
  23. src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc +0 -0
  24. src/facerender/modules/__pycache__/generator.cpython-38.pyc +0 -0
  25. src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc +0 -0
  26. src/facerender/modules/__pycache__/make_animation.cpython-38.pyc +0 -0
  27. src/facerender/modules/__pycache__/mapping.cpython-38.pyc +0 -0
  28. src/facerender/modules/__pycache__/util.cpython-38.pyc +0 -0
  29. src/{src/facerender → facerender}/pirender/base_function.py +0 -0
  30. src/{src/facerender → facerender}/pirender/config.py +0 -0
  31. src/{src/facerender → facerender}/pirender/face_model.py +0 -0
  32. src/{src/facerender → facerender}/pirender_animate.py +0 -0
  33. src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc +0 -0
  34. src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc +0 -0
  35. src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc +0 -0
  36. src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc +0 -0
  37. src/generate_facerender_batch.py +3 -4
  38. src/gradio_demo.py +21 -6
  39. src/src/audio2exp_models/audio2exp.py +0 -41
  40. src/src/audio2exp_models/networks.py +0 -74
  41. src/src/audio2pose_models/audio2pose.py +0 -94
  42. src/src/audio2pose_models/audio_encoder.py +0 -64
  43. src/src/audio2pose_models/cvae.py +0 -149
  44. src/src/audio2pose_models/discriminator.py +0 -76
  45. src/src/audio2pose_models/networks.py +0 -140
  46. src/src/audio2pose_models/res_unet.py +0 -65
  47. src/src/config/auido2exp.yaml +0 -58
  48. src/src/config/auido2pose.yaml +0 -49
  49. src/src/config/facerender.yaml +0 -45
  50. src/src/config/facerender_still.yaml +0 -45
src/__pycache__/generate_batch.cpython-38.pyc DELETED
Binary file (3.49 kB)
 
src/__pycache__/generate_facerender_batch.cpython-38.pyc DELETED
Binary file (4.06 kB)
 
src/__pycache__/test_audio2coeff.cpython-38.pyc DELETED
Binary file (3.91 kB)
 
src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc DELETED
Binary file (1.28 kB)
 
src/audio2exp_models/__pycache__/networks.cpython-38.pyc DELETED
Binary file (2.14 kB)
 
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc DELETED
Binary file (2.86 kB)
 
src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc DELETED
Binary file (2.17 kB)
 
src/audio2pose_models/__pycache__/cvae.cpython-38.pyc DELETED
Binary file (4.69 kB)
 
src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc DELETED
Binary file (2.45 kB)
 
src/audio2pose_models/__pycache__/networks.cpython-38.pyc DELETED
Binary file (4.74 kB)
 
src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc DELETED
Binary file (1.91 kB)
 
src/{src/config → config}/facerender_pirender.yaml RENAMED
File without changes
src/face3d/models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (3.27 kB)
 
src/face3d/models/__pycache__/base_model.cpython-38.pyc DELETED
Binary file (12.5 kB)
 
src/face3d/models/__pycache__/networks.cpython-38.pyc DELETED
Binary file (17.1 kB)
 
src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (861 Bytes)
 
src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc DELETED
Binary file (5.43 kB)
 
src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc DELETED
Binary file (5.49 kB)
 
src/face3d/util/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (294 Bytes)
 
src/face3d/util/__pycache__/load_mats.cpython-38.pyc DELETED
Binary file (2.95 kB)
 
src/face3d/util/__pycache__/preprocess.cpython-38.pyc DELETED
Binary file (3.34 kB)
 
src/facerender/__pycache__/animate.cpython-38.pyc DELETED
Binary file (6.91 kB)
 
src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc DELETED
Binary file (3.92 kB)
 
src/facerender/modules/__pycache__/generator.cpython-38.pyc DELETED
Binary file (6.59 kB)
 
src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc DELETED
Binary file (4.83 kB)
 
src/facerender/modules/__pycache__/make_animation.cpython-38.pyc DELETED
Binary file (4.76 kB)
 
src/facerender/modules/__pycache__/mapping.cpython-38.pyc DELETED
Binary file (1.69 kB)
 
src/facerender/modules/__pycache__/util.cpython-38.pyc DELETED
Binary file (17.2 kB)
 
src/{src/facerender → facerender}/pirender/base_function.py RENAMED
File without changes
src/{src/facerender → facerender}/pirender/config.py RENAMED
File without changes
src/{src/facerender → facerender}/pirender/face_model.py RENAMED
File without changes
src/{src/facerender → facerender}/pirender_animate.py RENAMED
File without changes
src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (403 Bytes)
 
src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc DELETED
Binary file (12.9 kB)
 
src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc DELETED
Binary file (4.84 kB)
 
src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc DELETED
Binary file (3.49 kB)
 
src/generate_facerender_batch.py CHANGED
@@ -7,7 +7,7 @@ import scipy.io as scio
7
 
8
  def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
9
  batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
10
- expression_scale=1.0, still_mode = False, preprocess='crop', size = 256):
11
 
12
  semantic_radius = 13
13
  video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
@@ -27,10 +27,9 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
27
  source_semantics_dict = scio.loadmat(first_coeff_path)
28
  generated_dict = scio.loadmat(coeff_path)
29
 
30
- if 'full' not in preprocess.lower():
31
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
32
  generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
33
-
34
  else:
35
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70
36
  generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
@@ -43,7 +42,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
43
  # target
44
  generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
45
 
46
- if 'full' in preprocess.lower():
47
  generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
48
 
49
  if still_mode:
 
7
 
8
  def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
9
  batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
10
+ expression_scale=1.0, still_mode = False, preprocess='crop', size = 256, facemodel='facevid2vid'):
11
 
12
  semantic_radius = 13
13
  video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
 
27
  source_semantics_dict = scio.loadmat(first_coeff_path)
28
  generated_dict = scio.loadmat(coeff_path)
29
 
30
+ if 'full' not in preprocess.lower() and facemodel != 'pirender':
31
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
32
  generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
 
33
  else:
34
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70
35
  generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
 
42
  # target
43
  generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
44
 
45
+ if 'full' in preprocess.lower() or facemodel == 'pirender':
46
  generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
47
 
48
  if still_mode:
src/gradio_demo.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch, uuid
2
- import os, sys, shutil
 
3
  from src.utils.preprocess import CropAndExtract
4
  from src.test_audio2coeff import Audio2Coeff
5
  from src.facerender.animate import AnimateFromCoeff
@@ -20,8 +21,10 @@ class SadTalker():
20
 
21
  def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
22
 
23
- if torch.cuda.is_available() :
24
  device = "cuda"
 
 
25
  else:
26
  device = "cpu"
27
 
@@ -35,7 +38,9 @@ class SadTalker():
35
 
36
  def test(self, source_image, driven_audio, preprocess='crop',
37
  still_mode=False, use_enhancer=False, batch_size=1, size=256,
38
- pose_style = 0, exp_scale=1.0,
 
 
39
  use_ref_video = False,
40
  ref_video = None,
41
  ref_info = None,
@@ -48,7 +53,15 @@ class SadTalker():
48
 
49
  self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
50
  self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
51
- self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
 
 
 
 
 
 
 
 
52
 
53
  time_tag = str(uuid.uuid4())
54
  save_dir = os.path.join(result_dir, time_tag)
@@ -131,11 +144,13 @@ class SadTalker():
131
  if use_ref_video and ref_info == 'all':
132
  coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
133
  else:
134
- batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
 
135
  coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
136
 
137
  #coeff2video
138
- data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale)
 
139
  return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
140
  video_name = data['video_name']
141
  print(f'The generated video is named {video_name} in {save_dir}')
 
1
  import torch, uuid
2
+ import os, sys, shutil, platform
3
+ from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
4
  from src.utils.preprocess import CropAndExtract
5
  from src.test_audio2coeff import Audio2Coeff
6
  from src.facerender.animate import AnimateFromCoeff
 
21
 
22
  def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
23
 
24
+ if torch.cuda.is_available():
25
  device = "cuda"
26
+ elif platform.system() == 'Darwin': # macos
27
+ device = "mps"
28
  else:
29
  device = "cpu"
30
 
 
38
 
39
  def test(self, source_image, driven_audio, preprocess='crop',
40
  still_mode=False, use_enhancer=False, batch_size=1, size=256,
41
+ pose_style = 0,
42
+ facerender='facevid2vid',
43
+ exp_scale=1.0,
44
  use_ref_video = False,
45
  ref_video = None,
46
  ref_info = None,
 
53
 
54
  self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
55
  self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
56
+
57
+ if facerender == 'facevid2vid' and self.device != 'mps':
58
+ self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
59
+ elif facerender == 'pirender' or self.device == 'mps':
60
+ self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)
61
+ facerender = 'pirender'
62
+ else:
63
+ raise(RuntimeError('Unknown model: {}'.format(facerender)))
64
+
65
 
66
  time_tag = str(uuid.uuid4())
67
  save_dir = os.path.join(result_dir, time_tag)
 
144
  if use_ref_video and ref_info == 'all':
145
  coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
146
  else:
147
+ batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, \
148
+ idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
149
  coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
150
 
151
  #coeff2video
152
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, \
153
+ preprocess=preprocess, size=size, expression_scale = exp_scale, facemodel=facerender)
154
  return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
155
  video_name = data['video_name']
156
  print(f'The generated video is named {video_name} in {save_dir}')
src/src/audio2exp_models/audio2exp.py DELETED
@@ -1,41 +0,0 @@
1
- from tqdm import tqdm
2
- import torch
3
- from torch import nn
4
-
5
-
6
- class Audio2Exp(nn.Module):
7
- def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
- super(Audio2Exp, self).__init__()
9
- self.cfg = cfg
10
- self.device = device
11
- self.netG = netG.to(device)
12
-
13
- def test(self, batch):
14
-
15
- mel_input = batch['indiv_mels'] # bs T 1 80 16
16
- bs = mel_input.shape[0]
17
- T = mel_input.shape[1]
18
-
19
- exp_coeff_pred = []
20
-
21
- for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
-
23
- current_mel_input = mel_input[:,i:i+10]
24
-
25
- #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
- ref = batch['ref'][:, :, :64][:, i:i+10]
27
- ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
-
29
- audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
-
31
- curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
-
33
- exp_coeff_pred += [curr_exp_coeff_pred]
34
-
35
- # BS x T x 64
36
- results_dict = {
37
- 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
- }
39
- return results_dict
40
-
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2exp_models/networks.py DELETED
@@ -1,74 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
- class Conv2d(nn.Module):
6
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
- super().__init__(*args, **kwargs)
8
- self.conv_block = nn.Sequential(
9
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
- nn.BatchNorm2d(cout)
11
- )
12
- self.act = nn.ReLU()
13
- self.residual = residual
14
- self.use_act = use_act
15
-
16
- def forward(self, x):
17
- out = self.conv_block(x)
18
- if self.residual:
19
- out += x
20
-
21
- if self.use_act:
22
- return self.act(out)
23
- else:
24
- return out
25
-
26
- class SimpleWrapperV2(nn.Module):
27
- def __init__(self) -> None:
28
- super().__init__()
29
- self.audio_encoder = nn.Sequential(
30
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
-
34
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
-
38
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
-
42
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
-
45
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
- )
48
-
49
- #### load the pre-trained audio_encoder
50
- #self.audio_encoder = self.audio_encoder.to(device)
51
- '''
52
- wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
- state_dict = self.audio_encoder.state_dict()
54
-
55
- for k,v in wav2lip_state_dict.items():
56
- if 'audio_encoder' in k:
57
- print('init:', k)
58
- state_dict[k.replace('module.audio_encoder.', '')] = v
59
- self.audio_encoder.load_state_dict(state_dict)
60
- '''
61
-
62
- self.mapping1 = nn.Linear(512+64+1, 64)
63
- #self.mapping2 = nn.Linear(30, 64)
64
- #nn.init.constant_(self.mapping1.weight, 0.)
65
- nn.init.constant_(self.mapping1.bias, 0.)
66
-
67
- def forward(self, x, ref, ratio):
68
- x = self.audio_encoder(x).view(x.size(0), -1)
69
- ref_reshape = ref.reshape(x.size(0), -1)
70
- ratio = ratio.reshape(x.size(0), -1)
71
-
72
- y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
- out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
74
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2pose_models/audio2pose.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from src.audio2pose_models.cvae import CVAE
4
- from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
- from src.audio2pose_models.audio_encoder import AudioEncoder
6
-
7
- class Audio2Pose(nn.Module):
8
- def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
- super().__init__()
10
- self.cfg = cfg
11
- self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
- self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
- self.device = device
14
-
15
- self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
- self.audio_encoder.eval()
17
- for param in self.audio_encoder.parameters():
18
- param.requires_grad = False
19
-
20
- self.netG = CVAE(cfg)
21
- self.netD_motion = PoseSequenceDiscriminator(cfg)
22
-
23
-
24
- def forward(self, x):
25
-
26
- batch = {}
27
- coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
- batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
- batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
- batch['class'] = x['class'].squeeze(0).cuda() # bs
31
- indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
-
33
- # forward
34
- audio_emb_list = []
35
- audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
- batch['audio_emb'] = audio_emb
37
- batch = self.netG(batch)
38
-
39
- pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
- pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
- pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
-
43
- batch['pose_pred'] = pose_pred
44
- batch['pose_gt'] = pose_gt
45
-
46
- return batch
47
-
48
- def test(self, x):
49
-
50
- batch = {}
51
- ref = x['ref'] #bs 1 70
52
- batch['ref'] = x['ref'][:,0,-6:]
53
- batch['class'] = x['class']
54
- bs = ref.shape[0]
55
-
56
- indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
- indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
- num_frames = x['num_frames']
59
- num_frames = int(num_frames) - 1
60
-
61
- #
62
- div = num_frames//self.seq_len
63
- re = num_frames%self.seq_len
64
- audio_emb_list = []
65
- pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
- device=batch['ref'].device)]
67
-
68
- for i in range(div):
69
- z = torch.randn(bs, self.latent_dim).to(ref.device)
70
- batch['z'] = z
71
- audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
- batch['audio_emb'] = audio_emb
73
- batch = self.netG.test(batch)
74
- pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
-
76
- if re != 0:
77
- z = torch.randn(bs, self.latent_dim).to(ref.device)
78
- batch['z'] = z
79
- audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
- if audio_emb.shape[1] != self.seq_len:
81
- pad_dim = self.seq_len-audio_emb.shape[1]
82
- pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
- audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
- batch['audio_emb'] = audio_emb
85
- batch = self.netG.test(batch)
86
- pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
-
88
- pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
- batch['pose_motion_pred'] = pose_motion_pred
90
-
91
- pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
-
93
- batch['pose_pred'] = pose_pred
94
- return batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2pose_models/audio_encoder.py DELETED
@@ -1,64 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
-
5
- class Conv2d(nn.Module):
6
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
- super().__init__(*args, **kwargs)
8
- self.conv_block = nn.Sequential(
9
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
- nn.BatchNorm2d(cout)
11
- )
12
- self.act = nn.ReLU()
13
- self.residual = residual
14
-
15
- def forward(self, x):
16
- out = self.conv_block(x)
17
- if self.residual:
18
- out += x
19
- return self.act(out)
20
-
21
- class AudioEncoder(nn.Module):
22
- def __init__(self, wav2lip_checkpoint, device):
23
- super(AudioEncoder, self).__init__()
24
-
25
- self.audio_encoder = nn.Sequential(
26
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
-
30
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
-
34
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
-
38
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
-
41
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
-
44
- #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
- # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
- # state_dict = self.audio_encoder.state_dict()
47
-
48
- # for k,v in wav2lip_state_dict.items():
49
- # if 'audio_encoder' in k:
50
- # state_dict[k.replace('module.audio_encoder.', '')] = v
51
- # self.audio_encoder.load_state_dict(state_dict)
52
-
53
-
54
- def forward(self, audio_sequences):
55
- # audio_sequences = (B, T, 1, 80, 16)
56
- B = audio_sequences.size(0)
57
-
58
- audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
-
60
- audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
- dim = audio_embedding.shape[1]
62
- audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
-
64
- return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2pose_models/cvae.py DELETED
@@ -1,149 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from src.audio2pose_models.res_unet import ResUnet
5
-
6
- def class2onehot(idx, class_num):
7
-
8
- assert torch.max(idx).item() < class_num
9
- onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
- onehot.scatter_(1, idx, 1)
11
- return onehot
12
-
13
- class CVAE(nn.Module):
14
- def __init__(self, cfg):
15
- super().__init__()
16
- encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
- decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
- latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
- num_classes = cfg.DATASET.NUM_CLASSES
20
- audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
- audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
- seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
-
24
- self.latent_size = latent_size
25
-
26
- self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
- audio_emb_in_size, audio_emb_out_size, seq_len)
28
- self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
- audio_emb_in_size, audio_emb_out_size, seq_len)
30
- def reparameterize(self, mu, logvar):
31
- std = torch.exp(0.5 * logvar)
32
- eps = torch.randn_like(std)
33
- return mu + eps * std
34
-
35
- def forward(self, batch):
36
- batch = self.encoder(batch)
37
- mu = batch['mu']
38
- logvar = batch['logvar']
39
- z = self.reparameterize(mu, logvar)
40
- batch['z'] = z
41
- return self.decoder(batch)
42
-
43
- def test(self, batch):
44
- '''
45
- class_id = batch['class']
46
- z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
- batch['z'] = z
48
- '''
49
- return self.decoder(batch)
50
-
51
- class ENCODER(nn.Module):
52
- def __init__(self, layer_sizes, latent_size, num_classes,
53
- audio_emb_in_size, audio_emb_out_size, seq_len):
54
- super().__init__()
55
-
56
- self.resunet = ResUnet()
57
- self.num_classes = num_classes
58
- self.seq_len = seq_len
59
-
60
- self.MLP = nn.Sequential()
61
- layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
- for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
- self.MLP.add_module(
64
- name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
- self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
-
67
- self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
- self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
- self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
-
71
- self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
-
73
- def forward(self, batch):
74
- class_id = batch['class']
75
- pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
- ref = batch['ref'] #bs 6
77
- bs = pose_motion_gt.shape[0]
78
- audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
-
80
- #pose encode
81
- pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
- pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
-
84
- #audio mapping
85
- print(audio_in.shape)
86
- audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
- audio_out = audio_out.reshape(bs, -1)
88
-
89
- class_bias = self.classbias[class_id] #bs latent_size
90
- x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
- x_out = self.MLP(x_in)
92
-
93
- mu = self.linear_means(x_out)
94
- logvar = self.linear_means(x_out) #bs latent_size
95
-
96
- batch.update({'mu':mu, 'logvar':logvar})
97
- return batch
98
-
99
- class DECODER(nn.Module):
100
- def __init__(self, layer_sizes, latent_size, num_classes,
101
- audio_emb_in_size, audio_emb_out_size, seq_len):
102
- super().__init__()
103
-
104
- self.resunet = ResUnet()
105
- self.num_classes = num_classes
106
- self.seq_len = seq_len
107
-
108
- self.MLP = nn.Sequential()
109
- input_size = latent_size + seq_len*audio_emb_out_size + 6
110
- for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
- self.MLP.add_module(
112
- name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
- if i+1 < len(layer_sizes):
114
- self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
- else:
116
- self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
-
118
- self.pose_linear = nn.Linear(6, 6)
119
- self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
-
121
- self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
-
123
- def forward(self, batch):
124
-
125
- z = batch['z'] #bs latent_size
126
- bs = z.shape[0]
127
- class_id = batch['class']
128
- ref = batch['ref'] #bs 6
129
- audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
- #print('audio_in: ', audio_in[:, :, :10])
131
-
132
- audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
- #print('audio_out: ', audio_out[:, :, :10])
134
- audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
- class_bias = self.classbias[class_id] #bs latent_size
136
-
137
- z = z + class_bias
138
- x_in = torch.cat([ref, z, audio_out], dim=-1)
139
- x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
- x_out = x_out.reshape((bs, self.seq_len, -1))
141
-
142
- #print('x_out: ', x_out)
143
-
144
- pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
-
146
- pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
-
148
- batch.update({'pose_motion_pred':pose_motion_pred})
149
- return batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2pose_models/discriminator.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
class ConvNormRelu(nn.Module):
    """Convolution -> normalization -> activation block for 1-D or 2-D input.

    Args:
        conv_type: ``'1d'`` or ``'2d'``; selects the conv / norm family.
        in_channels: number of input channels.
        out_channels: number of channels produced by the convolution.
        downsample: only consulted when ``kernel_size`` is ``None``; picks
            the (kernel, stride, padding) preset ``(4, 2, 1)`` instead of
            ``(3, 1, 1)``.
        kernel_size: explicit kernel size. When ``None`` a preset is used
            and any ``stride`` / ``padding`` passed in are overwritten.
        stride: explicit stride; defaults to 1 when ``kernel_size`` is
            given and ``stride`` is omitted.
        padding: explicit padding; defaults to 0 when ``kernel_size`` is
            given and ``padding`` is omitted.
        norm: ``'BN'`` (batch norm) or ``'IN'`` (instance norm).
        leaky: use ``LeakyReLU(0.2)`` instead of in-place ``ReLU``.

    Raises:
        NotImplementedError: for an unsupported ``conv_type`` or ``norm``.
    """

    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
        super().__init__()
        if kernel_size is None:
            kernel_size, stride, padding = (4, 2, 1) if downsample else (3, 1, 1)
        else:
            # Previously a bare kernel_size forwarded None into the conv
            # constructor; fall back to the framework's usual defaults.
            if stride is None:
                stride = 1
            if padding is None:
                padding = 0

        if conv_type == '2d':
            conv_cls = nn.Conv2d
            norm_classes = {'BN': nn.BatchNorm2d, 'IN': nn.InstanceNorm2d}
        elif conv_type == '1d':
            conv_cls = nn.Conv1d
            norm_classes = {'BN': nn.BatchNorm1d, 'IN': nn.InstanceNorm1d}
        else:
            # Without this branch self.conv was never set and the weight
            # init below failed with an unrelated AttributeError.
            raise NotImplementedError(
                'unsupported conv_type: {!r}'.format(conv_type))
        if norm not in norm_classes:
            raise NotImplementedError('unsupported norm: {!r}'.format(norm))

        # Attribute names and creation order (conv, norm, act) are kept so
        # existing checkpoints keep loading.
        self.conv = conv_cls(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        self.norm = norm_classes[norm](out_channels)
        nn.init.kaiming_normal_(self.conv.weight)

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, normalization and activation to ``x``."""
        x = self.conv(x)
        if isinstance(self.norm, nn.InstanceNorm1d):
            # Axes 1 and 2 are swapped around the norm so the statistics are
            # taken along the channel axis (original note: "normalize on [C]").
            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))
        else:
            x = self.norm(x)
        return self.act(x)
57
-
58
-
59
class PoseSequenceDiscriminator(nn.Module):
    """1-D convolutional discriminator over a pose sequence.

    The stack is two stride-2 ``ConvNormRelu`` stages (kernel 4, padding 1),
    one stride-1 stage, and a final single-channel ``Conv1d``. Each stride-2
    stage halves the temporal length (for even lengths), so the output has
    one score per 4 input steps.

    Args:
        cfg: config object; ``cfg.MODEL.DISCRIMINATOR.LEAKY_RELU`` and
            ``cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS`` are read.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        disc_cfg = self.cfg.MODEL.DISCRIMINATOR
        use_leaky = disc_cfg.LEAKY_RELU

        stages = [
            # B, C_in, T   -> B, 256, T/2
            ConvNormRelu('1d', disc_cfg.INPUT_CHANNELS, 256, downsample=True, leaky=use_leaky),
            # B, 256, T/2  -> B, 512, T/4
            ConvNormRelu('1d', 256, 512, downsample=True, leaky=use_leaky),
            # B, 512, T/4  -> B, 1024, T/4 (stride 1 keeps the length)
            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=use_leaky),
            # B, 1024, T/4 -> B, 1, T/4
            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True),
        ]
        self.seq = nn.Sequential(*stages)

    def forward(self, x):
        """Score ``x``; returns the conv output with its channel axis squeezed.

        ``x`` is flattened to ``(B, x.size(1), -1)`` and transposed so the
        second input axis becomes the convolution's temporal axis.
        """
        batch, steps = x.size(0), x.size(1)
        features_first = x.reshape(batch, steps, -1).transpose(1, 2)
        scores = self.seq(features_first)
        return scores.squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2pose_models/networks.py DELETED
@@ -1,140 +0,0 @@
1
- import torch.nn as nn
2
- import torch
3
-
4
-
5
class ResidualConv(nn.Module):
    """Pre-activation residual block with a convolutional shortcut.

    Main path: BN -> ReLU -> 3x3 conv (``stride``, ``padding``) -> BN ->
    ReLU -> 3x3 conv (padding 1). Shortcut: 3x3 conv (``stride``,
    padding 1) -> BN. The two results are summed.

    Args:
        input_dim: channels of the input feature map.
        output_dim: channels of the output feature map.
        stride: stride of the first main-path conv and of the shortcut conv.
        padding: padding of the first main-path conv only (the shortcut
            always pads by 1).
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super().__init__()

        main_path = [
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        ]
        self.conv_block = nn.Sequential(*main_path)

        shortcut = [
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        ]
        self.conv_skip = nn.Sequential(*shortcut)

    def forward(self, x):
        """Return main-path output plus shortcut output for ``x``."""
        transformed = self.conv_block(x)
        skipped = self.conv_skip(x)
        return transformed + skipped
27
-
28
-
29
class Upsample(nn.Module):
    """Learned upsampling via a single ``nn.ConvTranspose2d``.

    Args:
        input_dim: channels of the input feature map.
        output_dim: channels of the output feature map.
        kernel: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
    """

    def __init__(self, input_dim, output_dim, kernel, stride):
        super().__init__()

        deconv = nn.ConvTranspose2d(
            input_dim,
            output_dim,
            kernel_size=kernel,
            stride=stride,
        )
        self.upsample = deconv

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        upsampled = self.upsample(x)
        return upsampled
39
-
40
-
41
class Squeeze_Excite_Block(nn.Module):
    """Channel gating: global average pool -> bottleneck MLP -> sigmoid gate.

    Args:
        channel: number of channels of the input feature map.
        reduction: the hidden width of the MLP is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Scale every channel of the 4-D input ``x`` by its learned gate."""
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
57
-
58
-
59
class ASPP(nn.Module):
    """Parallel dilated 3x3 conv branches fused by a 1x1 conv.

    One branch (dilated conv -> ReLU -> BatchNorm) is built per entry of
    ``rate``; branch outputs are concatenated on the channel axis and
    projected back to ``out_dims`` channels.

    Branches are registered as ``aspp_block1``, ``aspp_block2``, ... so the
    state-dict keys of the default three-rate configuration are unchanged.

    Args:
        in_dims: channels of the input feature map.
        out_dims: channels of each branch and of the fused output.
        rate: dilation (and matching padding) of each branch. Any non-empty
            sequence is accepted; previously only length-3 sequences worked
            because three branches were hard-coded while the output conv
            was sized by ``len(rate)``.

    Raises:
        ValueError: if ``rate`` is empty.
    """

    def __init__(self, in_dims, out_dims, rate=(6, 12, 18)):
        super().__init__()

        rates = tuple(rate)
        if not rates:
            raise ValueError("rate must contain at least one dilation")

        branch_names = []
        for index, dilation in enumerate(rates, start=1):
            name = "aspp_block{:d}".format(index)
            branch = nn.Sequential(
                nn.Conv2d(
                    in_dims, out_dims, 3, stride=1, padding=dilation, dilation=dilation
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_dims),
            )
            setattr(self, name, branch)
            branch_names.append(name)
        # Plain tuple of attribute names; the modules themselves are
        # registered through setattr above.
        self._branch_names = tuple(branch_names)

        self.output = nn.Conv2d(len(rates) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        """Run every branch on ``x`` and fuse the concatenated results."""
        branch_outputs = [getattr(self, name)(x) for name in self._branch_names]
        return self.output(torch.cat(branch_outputs, dim=1))

    def _init_weights(self):
        """Kaiming-normal conv weights; BatchNorm scale 1 and shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
102
-
103
-
104
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling (thin wrapper over ``nn.Upsample``).

    Args:
        scale: multiplier passed to ``nn.Upsample`` as ``scale_factor``.
    """

    def __init__(self, scale=2):
        super().__init__()

        interpolator = nn.Upsample(scale_factor=scale, mode="bilinear")
        self.upsample = interpolator

    def forward(self, x):
        """Return ``x`` resized by the configured scale factor."""
        resized = self.upsample(x)
        return resized
112
-
113
-
114
class AttentionBlock(nn.Module):
    """Additive attention gate between an encoder and a decoder feature map.

    The encoder branch ends in a 2x2 max-pool, so ``x1`` must have twice the
    spatial size of ``x2`` for the two branch outputs to be summed.

    Args:
        input_encoder: channels of the encoder feature map ``x1``.
        input_decoder: channels of the decoder feature map ``x2``.
        output_dim: channels of the intermediate attention features.
    """

    def __init__(self, input_encoder, input_decoder, output_dim):
        super().__init__()

        encoder_layers = [
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        ]
        self.conv_encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        ]
        self.conv_decoder = nn.Sequential(*decoder_layers)

        attention_layers = [
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        ]
        self.conv_attn = nn.Sequential(*attention_layers)

    def forward(self, x1, x2):
        """Return ``x2`` weighted by a single-channel attention map.

        The map is computed from the sum of the encoder branch on ``x1`` and
        the decoder branch on ``x2`` and broadcast over ``x2``'s channels.
        """
        encoded = self.conv_encoder(x1)
        decoded = self.conv_decoder(x2)
        attention = self.conv_attn(encoded + decoded)
        return attention * x2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/audio2pose_models/res_unet.py DELETED
@@ -1,65 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from src.audio2pose_models.networks import ResidualConv, Upsample
4
-
5
-
6
class ResUnet(nn.Module):
    """Residual U-Net that resamples only along the first spatial axis.

    Three encoder stages use stride ``(2, 1)`` and three decoder stages use
    transposed convs with kernel/stride ``(2, 1)``, so the second spatial
    axis keeps its size throughout. For the skip concatenations to line up,
    the first spatial axis of the input has to be divisible by 8.

    Args:
        channel: channels of the input map.
        filters: channel widths of the four resolution levels; entries
            0..3 are read. The default is a tuple rather than a list so no
            mutable object is shared between calls.
    """

    def __init__(self, channel=1, filters=(32, 64, 128, 256)):
        super().__init__()
        f0, f1, f2, f3 = filters[0], filters[1], filters[2], filters[3]

        # Attribute names and creation order are kept so existing
        # checkpoints keep loading.
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, f0, kernel_size=3, padding=1),
            nn.BatchNorm2d(f0),
            nn.ReLU(),
            nn.Conv2d(f0, f0, kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, f0, kernel_size=3, padding=1)
        )

        # Encoder: each stage halves the first spatial axis.
        self.residual_conv_1 = ResidualConv(f0, f1, stride=(2, 1), padding=1)
        self.residual_conv_2 = ResidualConv(f1, f2, stride=(2, 1), padding=1)

        self.bridge = ResidualConv(f2, f3, stride=(2, 1), padding=1)

        # Decoder: upsample, concatenate the matching encoder map, refine.
        self.upsample_1 = Upsample(f3, f3, kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv1 = ResidualConv(f3 + f2, f2, stride=1, padding=1)

        self.upsample_2 = Upsample(f2, f2, kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv2 = ResidualConv(f2 + f1, f1, stride=1, padding=1)

        self.upsample_3 = Upsample(f1, f1, kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv3 = ResidualConv(f1 + f0, f0, stride=1, padding=1)

        # 1x1 conv to one channel, squashed into (0, 1).
        self.output_layer = nn.Sequential(
            nn.Conv2d(f0, 1, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map ``x`` to a single-channel map of the same spatial size."""
        # Encode
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv_1(x1)
        x3 = self.residual_conv_2(x2)
        # Bridge
        x4 = self.bridge(x3)

        # Decode
        x4 = self.upsample_1(x4)
        x5 = torch.cat([x4, x3], dim=1)
        x6 = self.up_residual_conv1(x5)

        x6 = self.upsample_2(x6)
        x7 = torch.cat([x6, x2], dim=1)
        x8 = self.up_residual_conv2(x7)

        x8 = self.upsample_3(x8)
        x9 = torch.cat([x8, x1], dim=1)
        x10 = self.up_residual_conv3(x9)

        return self.output_layer(x10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/config/auido2exp.yaml DELETED
@@ -1,58 +0,0 @@
1
- DATASET:
2
- TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
- EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
- TRAIN_BATCH_SIZE: 32
5
- EVAL_BATCH_SIZE: 32
6
- EXP: True
7
- EXP_DIM: 64
8
- FRAME_LEN: 32
9
- COEFF_LEN: 73
10
- NUM_CLASSES: 46
11
- AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
- COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
- LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
- DEBUG: True
15
- NUM_REPEATS: 2
16
- T: 40
17
-
18
-
19
- MODEL:
20
- FRAMEWORK: V2
21
- AUDIOENCODER:
22
- LEAKY_RELU: True
23
- NORM: 'IN'
24
- DISCRIMINATOR:
25
- LEAKY_RELU: False
26
- INPUT_CHANNELS: 6
27
- CVAE:
28
- AUDIO_EMB_IN_SIZE: 512
29
- AUDIO_EMB_OUT_SIZE: 128
30
- SEQ_LEN: 32
31
- LATENT_SIZE: 256
32
- ENCODER_LAYER_SIZES: [192, 1024]
33
- DECODER_LAYER_SIZES: [1024, 192]
34
-
35
-
36
- TRAIN:
37
- MAX_EPOCH: 300
38
- GENERATOR:
39
- LR: 2.0e-5
40
- DISCRIMINATOR:
41
- LR: 1.0e-5
42
- LOSS:
43
- W_FEAT: 0
44
- W_COEFF_EXP: 2
45
- W_LM: 1.0e-2
46
- W_LM_MOUTH: 0
47
- W_REG: 0
48
- W_SYNC: 0
49
- W_COLOR: 0
50
- W_EXPRESSION: 0
51
- W_LIPREADING: 0.01
52
- W_LIPREADING_VV: 0
53
- W_EYE_BLINK: 4
54
-
55
- TAG:
56
- NAME: small_dataset
57
-
58
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/config/auido2pose.yaml DELETED
@@ -1,49 +0,0 @@
1
- DATASET:
2
- TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
- EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
- TRAIN_BATCH_SIZE: 64
5
- EVAL_BATCH_SIZE: 1
6
- EXP: True
7
- EXP_DIM: 64
8
- FRAME_LEN: 32
9
- COEFF_LEN: 73
10
- NUM_CLASSES: 46
11
- AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
- COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
- DEBUG: True
14
-
15
-
16
- MODEL:
17
- AUDIOENCODER:
18
- LEAKY_RELU: True
19
- NORM: 'IN'
20
- DISCRIMINATOR:
21
- LEAKY_RELU: False
22
- INPUT_CHANNELS: 6
23
- CVAE:
24
- AUDIO_EMB_IN_SIZE: 512
25
- AUDIO_EMB_OUT_SIZE: 6
26
- SEQ_LEN: 32
27
- LATENT_SIZE: 64
28
- ENCODER_LAYER_SIZES: [192, 128]
29
- DECODER_LAYER_SIZES: [128, 192]
30
-
31
-
32
- TRAIN:
33
- MAX_EPOCH: 150
34
- GENERATOR:
35
- LR: 1.0e-4
36
- DISCRIMINATOR:
37
- LR: 1.0e-4
38
- LOSS:
39
- LAMBDA_REG: 1
40
- LAMBDA_LANDMARKS: 0
41
- LAMBDA_VERTICES: 0
42
- LAMBDA_GAN_MOTION: 0.7
43
- LAMBDA_GAN_COEFF: 0
44
- LAMBDA_KL: 1
45
-
46
- TAG:
47
- NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
-
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/config/facerender.yaml DELETED
@@ -1,45 +0,0 @@
1
- model_params:
2
- common_params:
3
- num_kp: 15
4
- image_channel: 3
5
- feature_channel: 32
6
- estimate_jacobian: False # True
7
- kp_detector_params:
8
- temperature: 0.1
9
- block_expansion: 32
10
- max_features: 1024
11
- scale_factor: 0.25 # 0.25
12
- num_blocks: 5
13
- reshape_channel: 16384 # 16384 = 1024 * 16
14
- reshape_depth: 16
15
- he_estimator_params:
16
- block_expansion: 64
17
- max_features: 2048
18
- num_bins: 66
19
- generator_params:
20
- block_expansion: 64
21
- max_features: 512
22
- num_down_blocks: 2
23
- reshape_channel: 32
24
- reshape_depth: 16 # 512 = 32 * 16
25
- num_resblocks: 6
26
- estimate_occlusion_map: True
27
- dense_motion_params:
28
- block_expansion: 32
29
- max_features: 1024
30
- num_blocks: 5
31
- reshape_depth: 16
32
- compress: 4
33
- discriminator_params:
34
- scales: [1]
35
- block_expansion: 32
36
- max_features: 512
37
- num_blocks: 4
38
- sn: True
39
- mapping_params:
40
- coeff_nc: 70
41
- descriptor_nc: 1024
42
- layer: 3
43
- num_kp: 15
44
- num_bins: 66
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/src/config/facerender_still.yaml DELETED
@@ -1,45 +0,0 @@
1
- model_params:
2
- common_params:
3
- num_kp: 15
4
- image_channel: 3
5
- feature_channel: 32
6
- estimate_jacobian: False # True
7
- kp_detector_params:
8
- temperature: 0.1
9
- block_expansion: 32
10
- max_features: 1024
11
- scale_factor: 0.25 # 0.25
12
- num_blocks: 5
13
- reshape_channel: 16384 # 16384 = 1024 * 16
14
- reshape_depth: 16
15
- he_estimator_params:
16
- block_expansion: 64
17
- max_features: 2048
18
- num_bins: 66
19
- generator_params:
20
- block_expansion: 64
21
- max_features: 512
22
- num_down_blocks: 2
23
- reshape_channel: 32
24
- reshape_depth: 16 # 512 = 32 * 16
25
- num_resblocks: 6
26
- estimate_occlusion_map: True
27
- dense_motion_params:
28
- block_expansion: 32
29
- max_features: 1024
30
- num_blocks: 5
31
- reshape_depth: 16
32
- compress: 4
33
- discriminator_params:
34
- scales: [1]
35
- block_expansion: 32
36
- max_features: 512
37
- num_blocks: 4
38
- sn: True
39
- mapping_params:
40
- coeff_nc: 73
41
- descriptor_nc: 1024
42
- layer: 3
43
- num_kp: 15
44
- num_bins: 66
45
-