FredWe committed on
Commit
88bbcbf
1 Parent(s): ad87fd0

Upload 36 files

Files changed (37)
  1. .gitattributes +1 -0
  2. data_utils/deepspeech_features/README.md +20 -0
  3. data_utils/deepspeech_features/deepspeech_features.py +275 -0
  4. data_utils/deepspeech_features/deepspeech_store.py +172 -0
  5. data_utils/deepspeech_features/extract_ds_features.py +132 -0
  6. data_utils/deepspeech_features/extract_wav.py +87 -0
  7. data_utils/deepspeech_features/fea_win.py +11 -0
  8. data_utils/face_parsing/79999_iter.pth +3 -0
  9. data_utils/face_parsing/logger.py +23 -0
  10. data_utils/face_parsing/model.py +285 -0
  11. data_utils/face_parsing/resnet.py +109 -0
  12. data_utils/face_parsing/test.py +98 -0
  13. data_utils/face_tracking/.DS_Store +0 -0
  14. data_utils/face_tracking/3DMM/.DS_Store +0 -0
  15. data_utils/face_tracking/3DMM/01_MorphableModel.mat +3 -0
  16. data_utils/face_tracking/3DMM/01_MorphableModel.mat.zip +3 -0
  17. data_utils/face_tracking/3DMM/exp_info.npy +3 -0
  18. data_utils/face_tracking/3DMM/keys_info.npy +3 -0
  19. data_utils/face_tracking/3DMM/sub_mesh.obj +0 -0
  20. data_utils/face_tracking/3DMM/topology_info.npy +3 -0
  21. data_utils/face_tracking/__init__.py +0 -0
  22. data_utils/face_tracking/convert_BFM.py +39 -0
  23. data_utils/face_tracking/data_loader.py +16 -0
  24. data_utils/face_tracking/face_tracker.py +390 -0
  25. data_utils/face_tracking/facemodel.py +153 -0
  26. data_utils/face_tracking/geo_transform.py +69 -0
  27. data_utils/face_tracking/render_3dmm.py +202 -0
  28. data_utils/face_tracking/render_land.py +192 -0
  29. data_utils/face_tracking/util.py +109 -0
  30. data_utils/hubert.py +92 -0
  31. data_utils/process.py +405 -0
  32. data_utils/wav2mel.py +167 -0
  33. data_utils/wav2mel_hparams.py +80 -0
  34. data_utils/wav2vec.py +420 -0
  35. tensorflow-models/deepspeech-0_1_0-b90017e8.pb.zip +3 -0
  36. torch-hub/2DFAN4-cd938726ad.zip +3 -0
  37. torch-hub/s3fd-619a316812.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data_utils/face_tracking/3DMM/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
data_utils/deepspeech_features/README.md ADDED
@@ -0,0 +1,20 @@
1
+ # Routines for DeepSpeech features processing
2
+ Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, such as speech feature generation for the [VOCA](https://github.com/TimoBolkart/voca) model.
3
+
4
+ ## Installation
5
+
6
+ ```
7
+ pip3 install -r requirements.txt
8
+ ```
9
+
10
+ ## Usage
11
+
12
+ Generate wav files:
13
+ ```
14
+ python3 extract_wav.py --in-video=<your_data_dir>
15
+ ```
16
+
17
+ Generate files with DeepSpeech features:
18
+ ```
19
+ python3 extract_ds_features.py --input=<your_data_dir>
20
+ ```
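The two commands above can also be reproduced from Python. The sketch below is a minimal, untested example that drives the helpers added in this commit (`get_deepspeech_model_file` from `deepspeech_store.py` and `conv_audios_to_deepspeech` from `deepspeech_features.py`); it assumes it is run from inside `data_utils/deepspeech_features/`, and the wav/output paths are placeholders.

```python
# Minimal sketch of the programmatic equivalent of extract_ds_features.py.
# Paths are placeholders; run from data_utils/deepspeech_features/ so the
# local modules are importable.
from deepspeech_store import get_deepspeech_model_file
from deepspeech_features import conv_audios_to_deepspeech

# Downloads deepspeech-0_1_0-b90017e8.pb to ~/.tensorflow/models on first use.
pb_path = get_deepspeech_model_file()

conv_audios_to_deepspeech(
    audios=["<your_data_dir>/aud.wav"],        # placeholder input wav (16 kHz mono)
    out_files=["<your_data_dir>/aud_ds.npy"],  # windowed DeepSpeech features
    num_frames_info=[None],                    # None: infer frame count from audio length
    deepspeech_pb_path=pb_path,
)
```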
data_utils/deepspeech_features/deepspeech_features.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ DeepSpeech features processing routines.
3
+ NB: Based on VOCA code. See the corresponding license restrictions.
4
+ """
5
+
6
+ __all__ = ['conv_audios_to_deepspeech']
7
+
8
+ import numpy as np
9
+ import warnings
10
+ import resampy
11
+ from scipy.io import wavfile
12
+ from python_speech_features import mfcc
13
+ import tensorflow.compat.v1 as tf
14
+ tf.disable_v2_behavior()
15
+
16
+ def conv_audios_to_deepspeech(audios,
17
+ out_files,
18
+ num_frames_info,
19
+ deepspeech_pb_path,
20
+ audio_window_size=1,
21
+ audio_window_stride=1):
22
+ """
23
+ Convert list of audio files into files with DeepSpeech features.
24
+
25
+ Parameters
26
+ ----------
27
+ audios : list of str or list of None
28
+ Paths to input audio files.
29
+ out_files : list of str
30
+ Paths to output files with DeepSpeech features.
31
+ num_frames_info : list of int
32
+ List of numbers of frames.
33
+ deepspeech_pb_path : str
34
+ Path to DeepSpeech 0.1.0 frozen model.
35
+ audio_window_size : int, default 1
36
+ Audio window size.
37
+ audio_window_stride : int, default 1
38
+ Audio window stride.
39
+ """
40
+ # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
41
+ graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
42
+ deepspeech_pb_path)
43
+
44
+ with tf.compat.v1.Session(graph=graph) as sess:
45
+ for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
46
+ print(audio_file_path)
47
+ print(out_file_path)
48
+ audio_sample_rate, audio = wavfile.read(audio_file_path)
49
+ if audio.ndim != 1:
50
+ warnings.warn(
51
+ "Audio has multiple channels, the first channel is used")
52
+ audio = audio[:, 0]
53
+ ds_features = pure_conv_audio_to_deepspeech(
54
+ audio=audio,
55
+ audio_sample_rate=audio_sample_rate,
56
+ audio_window_size=audio_window_size,
57
+ audio_window_stride=audio_window_stride,
58
+ num_frames=num_frames,
59
+ net_fn=lambda x: sess.run(
60
+ logits_ph,
61
+ feed_dict={
62
+ input_node_ph: x[np.newaxis, ...],
63
+ input_lengths_ph: [x.shape[0]]}))
64
+
65
+ net_output = ds_features.reshape(-1, 29)
66
+ win_size = 16
67
+ zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
68
+ net_output = np.concatenate(
69
+ (zero_pad, net_output, zero_pad), axis=0)
70
+ windows = []
71
+ for window_index in range(0, net_output.shape[0] - win_size, 2):
72
+ windows.append(
73
+ net_output[window_index:window_index + win_size])
74
+ print(np.array(windows).shape)
75
+ np.save(out_file_path, np.array(windows))
76
+
77
+
78
+ def prepare_deepspeech_net(deepspeech_pb_path):
79
+ """
80
+ Load and prepare DeepSpeech network.
81
+
82
+ Parameters
83
+ ----------
84
+ deepspeech_pb_path : str
85
+ Path to DeepSpeech 0.1.0 frozen model.
86
+
87
+ Returns
88
+ -------
89
+ graph : obj
90
+ ThensorFlow graph.
91
+ logits_ph : obj
92
+ ThensorFlow placeholder for `logits`.
93
+ input_node_ph : obj
94
+ ThensorFlow placeholder for `input_node`.
95
+ input_lengths_ph : obj
96
+ ThensorFlow placeholder for `input_lengths`.
97
+ """
98
+ # Load graph and place_holders:
99
+ with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
100
+ graph_def = tf.compat.v1.GraphDef()
101
+ graph_def.ParseFromString(f.read())
102
+
103
+ graph = tf.compat.v1.get_default_graph()
104
+ tf.import_graph_def(graph_def, name="deepspeech")
105
+ logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
106
+ input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
107
+ input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
108
+
109
+ return graph, logits_ph, input_node_ph, input_lengths_ph
110
+
111
+
112
+ def pure_conv_audio_to_deepspeech(audio,
113
+ audio_sample_rate,
114
+ audio_window_size,
115
+ audio_window_stride,
116
+ num_frames,
117
+ net_fn):
118
+ """
119
+ Core routine for converting audio into DeepSpeech features.
120
+
121
+ Parameters
122
+ ----------
123
+ audio : np.array
124
+ Audio data.
125
+ audio_sample_rate : int
126
+ Audio sample rate.
127
+ audio_window_size : int
128
+ Audio window size.
129
+ audio_window_stride : int
130
+ Audio window stride.
131
+ num_frames : int or None
132
+ Numbers of frames.
133
+ net_fn : func
134
+ Function for DeepSpeech model call.
135
+
136
+ Returns
137
+ -------
138
+ np.array
139
+ DeepSpeech features.
140
+ """
141
+ target_sample_rate = 16000
142
+ if audio_sample_rate != target_sample_rate:
143
+ resampled_audio = resampy.resample(
144
+ x=audio.astype(np.float),
145
+ sr_orig=audio_sample_rate,
146
+ sr_new=target_sample_rate)
147
+ else:
148
+ resampled_audio = audio.astype(np.float32)
149
+ input_vector = conv_audio_to_deepspeech_input_vector(
150
+ audio=resampled_audio.astype(np.int16),
151
+ sample_rate=target_sample_rate,
152
+ num_cepstrum=26,
153
+ num_context=9)
154
+
155
+ network_output = net_fn(input_vector)
156
+ # print(network_output.shape)
157
+
158
+ deepspeech_fps = 50
159
+ video_fps = 50 # Change this option if video fps is different
160
+ audio_len_s = float(audio.shape[0]) / audio_sample_rate
161
+ if num_frames is None:
162
+ num_frames = int(round(audio_len_s * video_fps))
163
+ else:
164
+ video_fps = num_frames / audio_len_s
165
+ network_output = interpolate_features(
166
+ features=network_output[:, 0],
167
+ input_rate=deepspeech_fps,
168
+ output_rate=video_fps,
169
+ output_len=num_frames)
170
+
171
+ # Make windows:
172
+ zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
173
+ network_output = np.concatenate(
174
+ (zero_pad, network_output, zero_pad), axis=0)
175
+ windows = []
176
+ for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
177
+ windows.append(
178
+ network_output[window_index:window_index + audio_window_size])
179
+
180
+ return np.array(windows)
181
+
182
+
183
+ def conv_audio_to_deepspeech_input_vector(audio,
184
+ sample_rate,
185
+ num_cepstrum,
186
+ num_context):
187
+ """
188
+ Convert audio raw data into DeepSpeech input vector.
189
+
190
+ Parameters
191
+ ----------
192
+ audio : np.array
193
+ Audio data.
194
+ sample_rate : int
195
+ Audio sample rate.
196
+ num_cepstrum : int
197
+ Number of cepstral coefficients.
198
+ num_context : int
199
+ Number of context frames on each side.
200
+
201
+ Returns
202
+ -------
203
+ np.array
204
+ DeepSpeech input vector.
205
+ """
206
+ # Get mfcc coefficients:
207
+ features = mfcc(
208
+ signal=audio,
209
+ samplerate=sample_rate,
210
+ numcep=num_cepstrum)
211
+
212
+ # We only keep every second feature (BiRNN stride = 2):
213
+ features = features[::2]
214
+
215
+ # One stride per time step in the input:
216
+ num_strides = len(features)
217
+
218
+ # Add empty initial and final contexts:
219
+ empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
220
+ features = np.concatenate((empty_context, features, empty_context))
221
+
222
+ # Create a view into the array with overlapping strides of size
223
+ # numcontext (past) + 1 (present) + numcontext (future):
224
+ window_size = 2 * num_context + 1
225
+ train_inputs = np.lib.stride_tricks.as_strided(
226
+ features,
227
+ shape=(num_strides, window_size, num_cepstrum),
228
+ strides=(features.strides[0],
229
+ features.strides[0], features.strides[1]),
230
+ writeable=False)
231
+
232
+ # Flatten the second and third dimensions:
233
+ train_inputs = np.reshape(train_inputs, [num_strides, -1])
234
+
235
+ train_inputs = np.copy(train_inputs)
236
+ train_inputs = (train_inputs - np.mean(train_inputs)) / \
237
+ np.std(train_inputs)
238
+
239
+ return train_inputs
240
+
241
+
242
+ def interpolate_features(features,
243
+ input_rate,
244
+ output_rate,
245
+ output_len):
246
+ """
247
+ Interpolate DeepSpeech features.
248
+
249
+ Parameters
250
+ ----------
251
+ features : np.array
252
+ DeepSpeech features.
253
+ input_rate : int
254
+ input rate (FPS).
255
+ output_rate : int
256
+ Output rate (FPS).
257
+ output_len : int
258
+ Output data length.
259
+
260
+ Returns
261
+ -------
262
+ np.array
263
+ Interpolated data.
264
+ """
265
+ input_len = features.shape[0]
266
+ num_features = features.shape[1]
267
+ input_timestamps = np.arange(input_len) / float(input_rate)
268
+ output_timestamps = np.arange(output_len) / float(output_rate)
269
+ output_features = np.zeros((output_len, num_features))
270
+ for feature_idx in range(num_features):
271
+ output_features[:, feature_idx] = np.interp(
272
+ x=output_timestamps,
273
+ xp=input_timestamps,
274
+ fp=features[:, feature_idx])
275
+ return output_features
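As a quick orientation for the file written above: `conv_audios_to_deepspeech` re-windows the 29-dimensional DeepSpeech logits into blocks of 16 frames with stride 2 (the `win_size = 16` loop), so the saved array has shape `(num_output_frames, 16, 29)`. A small sanity-check sketch, with a placeholder path:

```python
# Sanity check for a saved feature file (path is a placeholder).
import numpy as np

windows = np.load("<your_data_dir>/aud_ds.npy")
print(windows.shape)  # expected: (num_output_frames, 16, 29)
```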
data_utils/deepspeech_features/deepspeech_store.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Routines for loading DeepSpeech model.
3
+ """
4
+
5
+ __all__ = ['get_deepspeech_model_file']
6
+
7
+ import os
8
+ import zipfile
9
+ import logging
10
+ import hashlib
11
+
12
+
13
+ deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
14
+
15
+
16
+ def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
17
+ """
18
+ Return the location of the pretrained model on the local file system. This function will download it from the online model zoo when the
19
+ model cannot be found or its hash does not match. The root directory will be created if it doesn't exist.
20
+
21
+ Parameters
22
+ ----------
23
+ local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
24
+ Location for keeping the model parameters.
25
+
26
+ Returns
27
+ -------
28
+ file_path
29
+ Path to the requested pretrained model file.
30
+ """
31
+ sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
32
+ file_name = "deepspeech-0_1_0-b90017e8.pb"
33
+ local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
34
+ file_path = os.path.join(local_model_store_dir_path, file_name)
35
+ if os.path.exists(file_path):
36
+ if _check_sha1(file_path, sha1_hash):
37
+ return file_path
38
+ else:
39
+ logging.warning("Mismatch in the content of model file detected. Downloading again.")
40
+ else:
41
+ logging.info("Model file not found. Downloading to {}.".format(file_path))
42
+
43
+ if not os.path.exists(local_model_store_dir_path):
44
+ os.makedirs(local_model_store_dir_path)
45
+
46
+ zip_file_path = file_path + ".zip"
47
+ _download(
48
+ url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
49
+ repo_url=deepspeech_features_repo_url,
50
+ repo_release_tag="v0.0.1",
51
+ file_name=file_name),
52
+ path=zip_file_path,
53
+ overwrite=True)
54
+ with zipfile.ZipFile(zip_file_path) as zf:
55
+ zf.extractall(local_model_store_dir_path)
56
+ os.remove(zip_file_path)
57
+
58
+ if _check_sha1(file_path, sha1_hash):
59
+ return file_path
60
+ else:
61
+ raise ValueError("Downloaded file has different hash. Please try again.")
62
+
63
+
64
+ def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
65
+ """
66
+ Download an given URL
67
+
68
+ Parameters
69
+ ----------
70
+ url : str
71
+ URL to download
72
+ path : str, optional
73
+ Destination path to store downloaded file. By default stores to the
74
+ current directory with same name as in url.
75
+ overwrite : bool, optional
76
+ Whether to overwrite destination file if already exists.
77
+ sha1_hash : str, optional
78
+ Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
79
+ but doesn't match.
80
+ retries : integer, default 5
81
+ The number of times to attempt the download in case of failure or non 200 return codes
82
+ verify_ssl : bool, default True
83
+ Verify SSL certificates.
84
+
85
+ Returns
86
+ -------
87
+ str
88
+ The file path of the downloaded file.
89
+ """
90
+ import warnings
91
+ try:
92
+ import requests
93
+ except ImportError:
94
+ class requests_failed_to_import(object):
95
+ pass
96
+ requests = requests_failed_to_import
97
+
98
+ if path is None:
99
+ fname = url.split("/")[-1]
100
+ # Empty filenames are invalid
101
+ assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
102
+ else:
103
+ path = os.path.expanduser(path)
104
+ if os.path.isdir(path):
105
+ fname = os.path.join(path, url.split("/")[-1])
106
+ else:
107
+ fname = path
108
+ assert retries >= 0, "Number of retries should be at least 0"
109
+
110
+ if not verify_ssl:
111
+ warnings.warn(
112
+ "Unverified HTTPS request is being made (verify_ssl=False). "
113
+ "Adding certificate verification is strongly advised.")
114
+
115
+ if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
116
+ dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
117
+ if not os.path.exists(dirname):
118
+ os.makedirs(dirname)
119
+ while retries + 1 > 0:
120
+ # Disable pyling too broad Exception
121
+ # pylint: disable=W0703
122
+ try:
123
+ print("Downloading {} from {}...".format(fname, url))
124
+ r = requests.get(url, stream=True, verify=verify_ssl)
125
+ if r.status_code != 200:
126
+ raise RuntimeError("Failed downloading url {}".format(url))
127
+ with open(fname, "wb") as f:
128
+ for chunk in r.iter_content(chunk_size=1024):
129
+ if chunk: # filter out keep-alive new chunks
130
+ f.write(chunk)
131
+ if sha1_hash and not _check_sha1(fname, sha1_hash):
132
+ raise UserWarning("File {} is downloaded but the content hash does not match."
133
+ " The repo may be outdated or download may be incomplete. "
134
+ "If the `repo_url` is overridden, consider switching to "
135
+ "the default repo.".format(fname))
136
+ break
137
+ except Exception as e:
138
+ retries -= 1
139
+ if retries <= 0:
140
+ raise e
141
+ else:
142
+ print("download failed, retrying, {} attempt{} left"
143
+ .format(retries, "s" if retries > 1 else ""))
144
+
145
+ return fname
146
+
147
+
148
+ def _check_sha1(filename, sha1_hash):
149
+ """
150
+ Check whether the sha1 hash of the file content matches the expected hash.
151
+
152
+ Parameters
153
+ ----------
154
+ filename : str
155
+ Path to the file.
156
+ sha1_hash : str
157
+ Expected sha1 hash in hexadecimal digits.
158
+
159
+ Returns
160
+ -------
161
+ bool
162
+ Whether the file content matches the expected hash.
163
+ """
164
+ sha1 = hashlib.sha1()
165
+ with open(filename, "rb") as f:
166
+ while True:
167
+ data = f.read(1048576)
168
+ if not data:
169
+ break
170
+ sha1.update(data)
171
+
172
+ return sha1.hexdigest() == sha1_hash
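For completeness, a short sketch of using the store module on its own; it assumes network access to the `osmr/deepspeech_features` GitHub release referenced above:

```python
# Fetch (or reuse) the frozen DeepSpeech 0.1.0 graph.
from deepspeech_store import get_deepspeech_model_file

# Cached in ~/.tensorflow/models by default; re-downloaded only if the file
# is missing or its SHA-1 does not match the expected hash.
pb_path = get_deepspeech_model_file()
print(pb_path)  # .../deepspeech-0_1_0-b90017e8.pb
```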
data_utils/deepspeech_features/extract_ds_features.py ADDED
@@ -0,0 +1,132 @@
1
+ """
2
+ Script for extracting DeepSpeech features from audio file.
3
+ """
4
+
5
+ import os
6
+ import argparse
7
+ import numpy as np
8
+ import pandas as pd
9
+ from deepspeech_store import get_deepspeech_model_file
10
+ from deepspeech_features import conv_audios_to_deepspeech
11
+
12
+
13
+ def parse_args():
14
+ """
15
+ Create python script parameters.
16
+ Returns
17
+ -------
18
+ ArgumentParser
19
+ Parsed arguments.
20
+ """
21
+ parser = argparse.ArgumentParser(
22
+ description="Extract DeepSpeech features from audio file",
23
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24
+ parser.add_argument(
25
+ "--input",
26
+ type=str,
27
+ required=True,
28
+ help="path to input audio file or directory")
29
+ parser.add_argument(
30
+ "--output",
31
+ type=str,
32
+ help="path to output file with DeepSpeech features")
33
+ parser.add_argument(
34
+ "--deepspeech",
35
+ type=str,
36
+ help="path to DeepSpeech 0.1.0 frozen model")
37
+ parser.add_argument(
38
+ "--metainfo",
39
+ type=str,
40
+ help="path to file with meta-information")
41
+
42
+ args = parser.parse_args()
43
+ return args
44
+
45
+
46
+ def extract_features(in_audios,
47
+ out_files,
48
+ deepspeech_pb_path,
49
+ metainfo_file_path=None):
50
+ """
51
+ Extract DeepSpeech features for the given audio files.
52
+ Parameters
53
+ ----------
54
+ in_audios : list of str
55
+ Paths to input audio files.
56
+ out_files : list of str
57
+ Paths to output files with DeepSpeech features.
58
+ deepspeech_pb_path : str
59
+ Path to DeepSpeech 0.1.0 frozen model.
60
+ metainfo_file_path : str, default None
61
+ Path to file with meta-information.
62
+ """
63
+ #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
64
+ if metainfo_file_path is None:
65
+ num_frames_info = [None] * len(in_audios)
66
+ else:
67
+ train_df = pd.read_csv(
68
+ metainfo_file_path,
69
+ sep="\t",
70
+ index_col=False,
71
+ dtype={"Id": np.int, "File": np.unicode, "Count": np.int})
72
+ num_frames_info = train_df["Count"].values
73
+ assert (len(num_frames_info) == len(in_audios))
74
+
75
+ for i, in_audio in enumerate(in_audios):
76
+ if not out_files[i]:
77
+ file_stem, _ = os.path.splitext(in_audio)
78
+ out_files[i] = file_stem + ".npy"
79
+ #print(out_files[i])
80
+ conv_audios_to_deepspeech(
81
+ audios=in_audios,
82
+ out_files=out_files,
83
+ num_frames_info=num_frames_info,
84
+ deepspeech_pb_path=deepspeech_pb_path)
85
+
86
+
87
+ def main():
88
+ """
89
+ Main body of script.
90
+ """
91
+ args = parse_args()
92
+ in_audio = os.path.expanduser(args.input)
93
+ if not os.path.exists(in_audio):
94
+ raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
95
+ deepspeech_pb_path = args.deepspeech
96
+ #add
97
+ deepspeech_pb_path = True
98
+ args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
99
+ #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
100
+ if deepspeech_pb_path is None:
101
+ deepspeech_pb_path = ""
102
+ if deepspeech_pb_path:
103
+ deepspeech_pb_path = os.path.expanduser(args.deepspeech)
104
+ if not os.path.exists(deepspeech_pb_path):
105
+ deepspeech_pb_path = get_deepspeech_model_file()
106
+ if os.path.isfile(in_audio):
107
+ extract_features(
108
+ in_audios=[in_audio],
109
+ out_files=[args.output],
110
+ deepspeech_pb_path=deepspeech_pb_path,
111
+ metainfo_file_path=args.metainfo)
112
+ else:
113
+ audio_file_paths = []
114
+ for file_name in os.listdir(in_audio):
115
+ if not os.path.isfile(os.path.join(in_audio, file_name)):
116
+ continue
117
+ _, file_ext = os.path.splitext(file_name)
118
+ if file_ext.lower() == ".wav":
119
+ audio_file_path = os.path.join(in_audio, file_name)
120
+ audio_file_paths.append(audio_file_path)
121
+ audio_file_paths = sorted(audio_file_paths)
122
+ out_file_paths = [""] * len(audio_file_paths)
123
+ extract_features(
124
+ in_audios=audio_file_paths,
125
+ out_files=out_file_paths,
126
+ deepspeech_pb_path=deepspeech_pb_path,
127
+ metainfo_file_path=args.metainfo)
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
132
+
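The optional `--metainfo` file is read with `pd.read_csv(..., sep="\t")` and is expected to provide `Id`, `File`, and `Count` columns, one row per clip, with `Count` giving the target number of video frames. A hypothetical way to produce such a file (names and counts are made up):

```python
# Hypothetical metainfo TSV for --metainfo; column names follow the dtype
# mapping in extract_features above, file names and counts are placeholders.
import pandas as pd

meta = pd.DataFrame({
    "Id": [0, 1],
    "File": ["clip0.wav", "clip1.wav"],
    "Count": [250, 312],  # desired number of video frames per clip
})
meta.to_csv("metainfo.tsv", sep="\t", index=False)
```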
data_utils/deepspeech_features/extract_wav.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ Script for extracting audio (16-bit, mono, 16000 Hz) from a video file.
3
+ """
4
+
5
+ import os
6
+ import argparse
7
+ import subprocess
8
+
9
+
10
+ def parse_args():
11
+ """
12
+ Create python script parameters.
13
+
14
+ Returns
15
+ -------
16
+ ArgumentParser
17
+ Parsed arguments.
18
+ """
19
+ parser = argparse.ArgumentParser(
20
+ description="Extract audio from video file",
21
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22
+ parser.add_argument(
23
+ "--in-video",
24
+ type=str,
25
+ required=True,
26
+ help="path to input video file or directory")
27
+ parser.add_argument(
28
+ "--out-audio",
29
+ type=str,
30
+ help="path to output audio file")
31
+
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def extract_audio(in_video,
37
+ out_audio):
38
+ """
39
+ Extract audio from a video file.
40
+
41
+ Parameters
42
+ ----------
43
+ in_video : str
44
+ Path to input video file.
45
+ out_audio : str
46
+ Path to output audio file.
47
+ """
48
+ if not out_audio:
49
+ file_stem, _ = os.path.splitext(in_video)
50
+ out_audio = file_stem + ".wav"
51
+ # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
52
+ # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
53
+ # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
54
+ command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
55
+ subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)
56
+
57
+
58
+ def main():
59
+ """
60
+ Main body of script.
61
+ """
62
+ args = parse_args()
63
+ in_video = os.path.expanduser(args.in_video)
64
+ if not os.path.exists(in_video):
65
+ raise Exception("Input file/directory doesn't exist: {}".format(in_video))
66
+ if os.path.isfile(in_video):
67
+ extract_audio(
68
+ in_video=in_video,
69
+ out_audio=args.out_audio)
70
+ else:
71
+ video_file_paths = []
72
+ for file_name in os.listdir(in_video):
73
+ if not os.path.isfile(os.path.join(in_video, file_name)):
74
+ continue
75
+ _, file_ext = os.path.splitext(file_name)
76
+ if file_ext.lower() in (".mp4", ".mkv", ".avi"):
77
+ video_file_path = os.path.join(in_video, file_name)
78
+ video_file_paths.append(video_file_path)
79
+ video_file_paths = sorted(video_file_paths)
80
+ for video_file_path in video_file_paths:
81
+ extract_audio(
82
+ in_video=video_file_path,
83
+ out_audio="")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ main()
data_utils/deepspeech_features/fea_win.py ADDED
@@ -0,0 +1,11 @@
1
+ import numpy as np
2
+
3
+ net_output = np.load('french.ds.npy').reshape(-1, 29)
4
+ win_size = 16
5
+ zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
6
+ net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
7
+ windows = []
8
+ for window_index in range(0, net_output.shape[0] - win_size, 2):
9
+ windows.append(net_output[window_index:window_index + win_size])
10
+ print(np.array(windows).shape)
11
+ np.save('aud_french.npy', np.array(windows))
data_utils/face_parsing/79999_iter.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
3
+ size 53289463
data_utils/face_parsing/logger.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import os.path as osp
6
+ import time
7
+ import sys
8
+ import logging
9
+
10
+ import torch.distributed as dist
11
+
12
+
13
+ def setup_logger(logpth):
14
+ logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
15
+ logfile = osp.join(logpth, logfile)
16
+ FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
17
+ log_level = logging.INFO
18
+ if dist.is_initialized() and not dist.get_rank()==0:
19
+ log_level = logging.ERROR
20
+ logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
21
+ logging.root.addHandler(logging.StreamHandler())
22
+
23
+
data_utils/face_parsing/model.py ADDED
@@ -0,0 +1,285 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18()
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since it is replaced with the resnet feature of the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if not module.bias is None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if not module.bias is None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, n_classes, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath()
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+
255
+ # return feat_out, feat_out16, feat_out32
256
+ return feat_out
257
+
258
+ def init_weight(self):
259
+ for ly in self.children():
260
+ if isinstance(ly, nn.Conv2d):
261
+ nn.init.kaiming_normal_(ly.weight, a=1)
262
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
263
+
264
+ def get_params(self):
265
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
266
+ for name, child in self.named_children():
267
+ child_wd_params, child_nowd_params = child.get_params()
268
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
269
+ lr_mul_wd_params += child_wd_params
270
+ lr_mul_nowd_params += child_nowd_params
271
+ else:
272
+ wd_params += child_wd_params
273
+ nowd_params += child_nowd_params
274
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
275
+
276
+
277
+ if __name__ == "__main__":
278
+ net = BiSeNet(19)
279
+ net.cuda()
280
+ net.eval()
281
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
282
+ out, out16, out32 = net(in_ten)
283
+ print(out.shape)
284
+
285
+ net.get_params()
data_utils/face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight()
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self):
83
+ state_dict = modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if not module.bias is None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
data_utils/face_parsing/test.py ADDED
@@ -0,0 +1,98 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+ import numpy as np
4
+ from model import BiSeNet
5
+
6
+ import torch
7
+
8
+ import os
9
+ import os.path as osp
10
+
11
+ from PIL import Image
12
+ import torchvision.transforms as transforms
13
+ import cv2
14
+ from pathlib import Path
15
+ import configargparse
16
+ import tqdm
17
+
18
+ # import ttach as tta
19
+
20
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
21
+ img_size=(512, 512)):
22
+ im = np.array(im)
23
+ vis_im = im.copy().astype(np.uint8)
24
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
25
+ vis_parsing_anno = cv2.resize(
26
+ vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
27
+ vis_parsing_anno_color = np.zeros(
28
+ (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
29
+
30
+ num_of_class = np.max(vis_parsing_anno)
31
+ # print(num_of_class)
32
+ for pi in range(1, 14):
33
+ index = np.where(vis_parsing_anno == pi)
34
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
35
+
36
+ for pi in range(14, 16):
37
+ index = np.where(vis_parsing_anno == pi)
38
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
39
+ for pi in range(16, 17):
40
+ index = np.where(vis_parsing_anno == pi)
41
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
42
+ for pi in range(17, num_of_class+1):
43
+ index = np.where(vis_parsing_anno == pi)
44
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
45
+
46
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
47
+ index = np.where(vis_parsing_anno == num_of_class-1)
48
+ vis_im = cv2.resize(vis_parsing_anno_color, img_size,
49
+ interpolation=cv2.INTER_NEAREST)
50
+ if save_im:
51
+ cv2.imwrite(save_path, vis_im)
52
+
53
+
54
+ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
55
+
56
+ Path(respth).mkdir(parents=True, exist_ok=True)
57
+
58
+ print(f'[INFO] loading model...')
59
+ n_classes = 19
60
+ net = BiSeNet(n_classes=n_classes)
61
+ net.cuda()
62
+ net.load_state_dict(torch.load(cp))
63
+ net.eval()
64
+
65
+ to_tensor = transforms.Compose([
66
+ transforms.ToTensor(),
67
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
68
+ ])
69
+
70
+ image_paths = os.listdir(dspth)
71
+
72
+ with torch.no_grad():
73
+ for image_path in tqdm.tqdm(image_paths):
74
+ if image_path.endswith('.jpg') or image_path.endswith('.png'):
75
+ img = Image.open(osp.join(dspth, image_path))
76
+ ori_size = img.size
77
+ image = img.resize((512, 512), Image.BILINEAR)
78
+ image = image.convert("RGB")
79
+ img = to_tensor(image)
80
+
81
+ # test-time augmentation.
82
+ inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512]
83
+ outputs = net(inputs.cuda())
84
+ parsing = outputs.mean(0).cpu().numpy().argmax(0)
85
+
86
+ image_path = int(image_path[:-4])
87
+ image_path = str(image_path) + '.png'
88
+
89
+ vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ parser = configargparse.ArgumentParser()
94
+ parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
95
+ parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
96
+ parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
97
+ args = parser.parse_args()
98
+ evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
data_utils/face_tracking/.DS_Store ADDED
Binary file (6.15 kB)
data_utils/face_tracking/3DMM/.DS_Store ADDED
Binary file (6.15 kB)
data_utils/face_tracking/3DMM/01_MorphableModel.mat ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
3
+ size 240875364
data_utils/face_tracking/3DMM/01_MorphableModel.mat.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5338eeb60e7b6702cf4ed728090ffe892b8dde0e42bba647184b0359f6e991aa
3
+ size 240178121
data_utils/face_tracking/3DMM/exp_info.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3196029ed038eb9a461df6c782125fc4d1ec1545f2e5f361891471136b6cbb6
3
+ size 33264853
data_utils/face_tracking/3DMM/keys_info.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:028d3c383bae129c4bdcac880e22c551a5d2436ec7db7a26f5c57148d12469e6
3
+ size 7375
data_utils/face_tracking/3DMM/sub_mesh.obj ADDED
The diff for this file is too large to render.
data_utils/face_tracking/3DMM/topology_info.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2edff6a6ad574d2dddf0d0815e0beabfea9369d7c2d6e53e0ba81f809b81e963
3
+ size 4145201
data_utils/face_tracking/__init__.py ADDED
File without changes
data_utils/face_tracking/convert_BFM.py ADDED
@@ -0,0 +1,39 @@
1
+ import numpy as np
2
+ from scipy.io import loadmat
3
+
4
+ original_BFM = loadmat("3DMM/01_MorphableModel.mat")
5
+ sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"]
6
+
7
+ shapePC = original_BFM["shapePC"]
8
+ shapeEV = original_BFM["shapeEV"]
9
+ shapeMU = original_BFM["shapeMU"]
10
+ texPC = original_BFM["texPC"]
11
+ texEV = original_BFM["texEV"]
12
+ texMU = original_BFM["texMU"]
13
+
14
+ b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
15
+ mu_shape = shapeMU.reshape(-1, 3)
16
+
17
+ b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
18
+ mu_tex = texMU.reshape(-1, 3)
19
+
20
+ b_shape = b_shape[:, sub_inds, :].reshape(199, -1)
21
+ mu_shape = mu_shape[sub_inds, :].reshape(-1)
22
+ b_tex = b_tex[:, sub_inds, :].reshape(199, -1)
23
+ mu_tex = mu_tex[sub_inds, :].reshape(-1)
24
+
25
+ exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item()
26
+ np.save(
27
+ "3DMM/3DMM_info.npy",
28
+ {
29
+ "mu_shape": mu_shape,
30
+ "b_shape": b_shape,
31
+ "sig_shape": shapeEV.reshape(-1),
32
+ "mu_exp": exp_info["mu_exp"],
33
+ "b_exp": exp_info["base_exp"],
34
+ "sig_exp": exp_info["sig_exp"],
35
+ "mu_tex": mu_tex,
36
+ "b_tex": b_tex,
37
+ "sig_tex": texEV.reshape(-1),
38
+ },
39
+ )
data_utils/face_tracking/data_loader.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
+ def load_dir(path, start, end):
7
+ lmss = []
8
+ imgs_paths = []
9
+ for i in range(start, end):
10
+ if os.path.isfile(os.path.join(path, str(i) + ".lms")):
11
+ lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32)
12
+ lmss.append(lms)
13
+ imgs_paths.append(os.path.join(path, str(i) + ".jpg"))
14
+ lmss = np.stack(lmss)
15
+ lmss = torch.as_tensor(lmss).cuda()
16
+ return lmss, imgs_paths
data_utils/face_tracking/face_tracker.py ADDED
@@ -0,0 +1,390 @@
1
+ import os
2
+ import sys
3
+ import cv2
4
+ import argparse
5
+ from pathlib import Path
6
+ import torch
7
+ import numpy as np
8
+ from data_loader import load_dir
9
+ from facemodel import Face_3DMM
10
+ from util import *
11
+ from render_3dmm import Render_3DMM
12
+
13
+
14
+ # torch.autograd.set_detect_anomaly(True)
15
+
16
+ dir_path = os.path.dirname(os.path.realpath(__file__))
17
+
18
+
19
+ def set_requires_grad(tensor_list):
20
+ for tensor in tensor_list:
21
+ tensor.requires_grad = True
22
+
23
+
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument(
26
+ "--path", type=str, default="obama/ori_imgs", help="idname of target person"
27
+ )
28
+ parser.add_argument("--img_h", type=int, default=512, help="image height")
29
+ parser.add_argument("--img_w", type=int, default=512, help="image width")
30
+ parser.add_argument("--frame_num", type=int, default=11000, help="image number")
31
+ args = parser.parse_args()
32
+
33
+ start_id = 0
34
+ end_id = args.frame_num
35
+
36
+ lms, img_paths = load_dir(args.path, start_id, end_id)
37
+ num_frames = lms.shape[0]
38
+ h, w = args.img_h, args.img_w
39
+ cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()
40
+ id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
41
+ model_3dmm = Face_3DMM(
42
+ os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num
43
+ )
44
+
45
+ # use only one image out of every 40 to fit the focal length
46
+ sel_ids = np.arange(0, num_frames, 40)
47
+ sel_num = sel_ids.shape[0]
48
+ arg_focal = 1600
49
+ arg_landis = 1e5
50
+
51
+ print(f'[INFO] fitting focal length...')
52
+
53
+ # fit the focal length
54
+ for focal in range(600, 1500, 100):
55
+ id_para = lms.new_zeros((1, id_dim), requires_grad=True)
56
+ exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
57
+ euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
58
+ trans = lms.new_zeros((sel_num, 3), requires_grad=True)
59
+ trans.data[:, 2] -= 7
60
+ focal_length = lms.new_zeros(1, requires_grad=False)
61
+ focal_length.data += focal
62
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
63
+
64
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
65
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)
66
+
67
+ for iter in range(2000):
68
+ id_para_batch = id_para.expand(sel_num, -1)
69
+ geometry = model_3dmm.get_3dlandmarks(
70
+ id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
71
+ )
72
+ proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
73
+ loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
74
+ loss = loss_lan
75
+ optimizer_frame.zero_grad()
76
+ loss.backward()
77
+ optimizer_frame.step()
78
+ # if iter % 100 == 0:
79
+ # print(focal, 'pose', iter, loss.item())
80
+
81
+ for iter in range(2500):
82
+ id_para_batch = id_para.expand(sel_num, -1)
83
+ geometry = model_3dmm.get_3dlandmarks(
84
+ id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
85
+ )
86
+ proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
87
+ loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
88
+ loss_regid = torch.mean(id_para * id_para)
89
+ loss_regexp = torch.mean(exp_para * exp_para)
90
+ loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
91
+ optimizer_idexp.zero_grad()
92
+ optimizer_frame.zero_grad()
93
+ loss.backward()
94
+ optimizer_idexp.step()
95
+ optimizer_frame.step()
96
+ # if iter % 100 == 0:
97
+ # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
98
+
99
+ if iter % 1500 == 0 and iter >= 1500:
100
+ for param_group in optimizer_idexp.param_groups:
101
+ param_group["lr"] *= 0.2
102
+ for param_group in optimizer_frame.param_groups:
103
+ param_group["lr"] *= 0.2
104
+
105
+ print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
106
+
107
+ if loss_lan.item() < arg_landis:
108
+ arg_landis = loss_lan.item()
109
+ arg_focal = focal
110
+
111
+ print("[INFO] find best focal:", arg_focal)
112
+
113
+ print(f'[INFO] coarse fitting...')
114
+
115
+ # for all frames, do a coarse fitting ???
116
+ id_para = lms.new_zeros((1, id_dim), requires_grad=True)
117
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
118
+ tex_para = lms.new_zeros(
119
+ (1, tex_dim), requires_grad=True
120
+ ) # not optimized in this block ???
121
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
122
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
123
+ light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
124
+ trans.data[:, 2] -= 7 # ???
125
+ focal_length = lms.new_zeros(1, requires_grad=True)
126
+ focal_length.data += arg_focal
127
+
128
+ set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para])
129
+
130
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
131
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)
132
+
133
+ for iter in range(1500):
134
+ id_para_batch = id_para.expand(num_frames, -1)
135
+ geometry = model_3dmm.get_3dlandmarks(
136
+ id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
137
+ )
138
+ proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
139
+ loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
140
+ loss = loss_lan
141
+ optimizer_frame.zero_grad()
142
+ loss.backward()
143
+ optimizer_frame.step()
144
+ if iter == 1000:
145
+ for param_group in optimizer_frame.param_groups:
146
+ param_group["lr"] = 0.1
147
+ # if iter % 100 == 0:
148
+ # print('pose', iter, loss.item())
149
+
150
+ for param_group in optimizer_frame.param_groups:
151
+ param_group["lr"] = 0.1
152
+
153
+ for iter in range(2000):
154
+ id_para_batch = id_para.expand(num_frames, -1)
155
+ geometry = model_3dmm.get_3dlandmarks(
156
+ id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
157
+ )
158
+ proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
159
+ loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
160
+ loss_regid = torch.mean(id_para * id_para)
161
+ loss_regexp = torch.mean(exp_para * exp_para)
162
+ loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
163
+ optimizer_idexp.zero_grad()
164
+ optimizer_frame.zero_grad()
165
+ loss.backward()
166
+ optimizer_idexp.step()
167
+ optimizer_frame.step()
168
+ # if iter % 100 == 0:
169
+ # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
170
+ if iter % 1000 == 0 and iter >= 1000:
171
+ for param_group in optimizer_idexp.param_groups:
172
+ param_group["lr"] *= 0.2
173
+ for param_group in optimizer_frame.param_groups:
174
+ param_group["lr"] *= 0.2
175
+
176
+ print(loss_lan.item(), torch.mean(trans[:, 2]).item())
177
+
178
+ print(f'[INFO] fitting light...')
179
+
180
+ batch_size = 32
181
+
182
+ device_default = torch.device("cuda:0")
183
+ device_render = torch.device("cuda:0")
184
+ renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
185
+
186
+ sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
187
+ imgs = []
188
+ for sel_id in sel_ids:
189
+ imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
190
+ imgs = np.stack(imgs)
191
+ sel_imgs = torch.as_tensor(imgs).cuda()
192
+ sel_lms = lms[sel_ids]
193
+ sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
194
+ set_requires_grad([sel_light])
195
+
196
+ optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
197
+ optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)
198
+
199
+ for iter in range(71):
200
+ sel_exp_para, sel_euler, sel_trans = (
201
+ exp_para[sel_ids],
202
+ euler_angle[sel_ids],
203
+ trans[sel_ids],
204
+ )
205
+ sel_id_para = id_para.expand(batch_size, -1)
206
+ geometry = model_3dmm.get_3dlandmarks(
207
+ sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
208
+ )
209
+ proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
210
+
211
+ loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
212
+ loss_regid = torch.mean(id_para * id_para)
213
+ loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
214
+
215
+ sel_tex_para = tex_para.expand(batch_size, -1)
216
+ sel_texture = model_3dmm.forward_tex(sel_tex_para)
217
+ geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
218
+ rott_geo = forward_rott(geometry, sel_euler, sel_trans)
219
+ render_imgs = renderer(
220
+ rott_geo.to(device_render),
221
+ sel_texture.to(device_render),
222
+ sel_light.to(device_render),
223
+ )
224
+ render_imgs = render_imgs.to(device_default)
225
+
226
+ mask = (render_imgs[:, :, :, 3]).detach() > 0.0
227
+ render_proj = sel_imgs.clone()
228
+ render_proj[mask] = render_imgs[mask][..., :3].byte()
229
+ loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
230
+
231
+ if iter > 50:
232
+ loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
233
+ else:
234
+ loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0
235
+
236
+ optimizer_tl.zero_grad()
237
+ optimizer_id_frame.zero_grad()
238
+ loss.backward()
239
+
240
+ optimizer_tl.step()
241
+ optimizer_id_frame.step()
242
+
243
+ if iter % 50 == 0 and iter > 0:
244
+ for param_group in optimizer_id_frame.param_groups:
245
+ param_group["lr"] *= 0.2
246
+ for param_group in optimizer_tl.param_groups:
247
+ param_group["lr"] *= 0.2
248
+ # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item())
249
+
250
+
251
+ light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
252
+ light_para.data = light_mean
253
+
254
+ exp_para = exp_para.detach()
255
+ euler_angle = euler_angle.detach()
256
+ trans = trans.detach()
257
+ light_para = light_para.detach()
258
+
259
+ print(f'[INFO] fine frame-wise fitting...')
260
+
261
+ for i in range(int((num_frames - 1) / batch_size + 1)):
262
+
263
+ if (i + 1) * batch_size > num_frames:
264
+ start_n = num_frames - batch_size
265
+ sel_ids = np.arange(num_frames - batch_size, num_frames)
266
+ else:
267
+ start_n = i * batch_size
268
+ sel_ids = np.arange(i * batch_size, i * batch_size + batch_size)
269
+
270
+ imgs = []
271
+ for sel_id in sel_ids:
272
+ imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
273
+ imgs = np.stack(imgs)
274
+ sel_imgs = torch.as_tensor(imgs).cuda()
275
+ sel_lms = lms[sel_ids]
276
+
277
+ sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True)
278
+ sel_exp_para.data = exp_para[sel_ids].clone()
279
+ sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True)
280
+ sel_euler.data = euler_angle[sel_ids].clone()
281
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
282
+ sel_trans.data = trans[sel_ids].clone()
283
+ sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
284
+ sel_light.data = light_para[sel_ids].clone()
285
+
286
+ set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light])
287
+
288
+ optimizer_cur_batch = torch.optim.Adam(
289
+ [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
290
+ )
291
+
292
+ sel_id_para = id_para.expand(batch_size, -1).detach()
293
+ sel_tex_para = tex_para.expand(batch_size, -1).detach()
294
+
295
+ pre_num = 5
296
+
297
+ if i > 0:
298
+ pre_ids = np.arange(start_n - pre_num, start_n)
299
+
300
+ for iter in range(50):
301
+
302
+ geometry = model_3dmm.get_3dlandmarks(
303
+ sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
304
+ )
305
+ proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
306
+ loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
307
+ loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
308
+
309
+ sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
310
+ sel_texture = model_3dmm.forward_tex(sel_tex_para)
311
+ geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
312
+ rott_geo = forward_rott(geometry, sel_euler, sel_trans)
313
+ render_imgs = renderer(
314
+ rott_geo.to(device_render),
315
+ sel_texture.to(device_render),
316
+ sel_light.to(device_render),
317
+ )
318
+ render_imgs = render_imgs.to(device_default)
319
+
320
+ mask = (render_imgs[:, :, :, 3]).detach() > 0.0
321
+
322
+ loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
323
+
324
+ if i > 0:
325
+ geometry_lap = model_3dmm.forward_geo_sub(
326
+ id_para.expand(batch_size + pre_num, -1).detach(),
327
+ torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
328
+ model_3dmm.rigid_ids,
329
+ )
330
+ rott_geo_lap = forward_rott(
331
+ geometry_lap,
332
+ torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
333
+ torch.cat((trans[pre_ids].detach(), sel_trans)),
334
+ )
335
+ loss_lap = cal_lap_loss(
336
+ [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
337
+ )
338
+ else:
339
+ geometry_lap = model_3dmm.forward_geo_sub(
340
+ id_para.expand(batch_size, -1).detach(),
341
+ sel_exp_para,
342
+ model_3dmm.rigid_ids,
343
+ )
344
+ rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
345
+ loss_lap = cal_lap_loss(
346
+ [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
347
+ )
348
+
349
+
350
+ if iter > 30:
351
+ loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
352
+ else:
353
+ loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0
354
+
355
+ optimizer_cur_batch.zero_grad()
356
+ loss.backward()
357
+ optimizer_cur_batch.step()
358
+
359
+ # if iter % 10 == 0:
360
+ # print(
361
+ # i,
362
+ # iter,
363
+ # loss_col.item(),
364
+ # loss_lan.item(),
365
+ # loss_lap.item(),
366
+ # loss_regexp.item(),
367
+ # )
368
+
369
+ print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done")
370
+
371
+ render_proj = sel_imgs.clone()
372
+ render_proj[mask] = render_imgs[mask][..., :3].byte()
373
+
374
+ exp_para[sel_ids] = sel_exp_para.clone()
375
+ euler_angle[sel_ids] = sel_euler.clone()
376
+ trans[sel_ids] = sel_trans.clone()
377
+ light_para[sel_ids] = sel_light.clone()
378
+
379
+ torch.save(
380
+ {
381
+ "id": id_para.detach().cpu(),
382
+ "exp": exp_para.detach().cpu(),
383
+ "euler": euler_angle.detach().cpu(),
384
+ "trans": trans.detach().cpu(),
385
+ "focal": focal_length.detach().cpu(),
386
+ },
387
+ os.path.join(os.path.dirname(args.path), "track_params.pt"),
388
+ )
389
+
390
+ print("params saved")
data_utils/face_tracking/facemodel.py ADDED
@@ -0,0 +1,153 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import os
5
+ from util import *
6
+
7
+
8
+ class Face_3DMM(nn.Module):
9
+ def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num):
10
+ super(Face_3DMM, self).__init__()
11
+ # id_dim = 100
12
+ # exp_dim = 79
13
+ # tex_dim = 100
14
+ self.point_num = point_num
15
+ DMM_info = np.load(
16
+ os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True
17
+ ).item()
18
+ base_id = DMM_info["b_shape"][:id_dim, :]
19
+ mu_id = DMM_info["mu_shape"]
20
+ base_exp = DMM_info["b_exp"][:exp_dim, :]
21
+ mu_exp = DMM_info["mu_exp"]
22
+ mu = mu_id + mu_exp
23
+ mu = mu.reshape(-1, 3)
24
+ for i in range(3):
25
+ mu[:, i] -= np.mean(mu[:, i])
26
+ mu = mu.reshape(-1)
27
+ self.base_id = torch.as_tensor(base_id).cuda() / 100000.0
28
+ self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0
29
+ self.mu = torch.as_tensor(mu).cuda() / 100000.0
30
+ base_tex = DMM_info["b_tex"][:tex_dim, :]
31
+ mu_tex = DMM_info["mu_tex"]
32
+ self.base_tex = torch.as_tensor(base_tex).cuda()
33
+ self.mu_tex = torch.as_tensor(mu_tex).cuda()
34
+ sig_id = DMM_info["sig_shape"][:id_dim]
35
+ sig_tex = DMM_info["sig_tex"][:tex_dim]
36
+ sig_exp = DMM_info["sig_exp"][:exp_dim]
37
+ self.sig_id = torch.as_tensor(sig_id).cuda()
38
+ self.sig_tex = torch.as_tensor(sig_tex).cuda()
39
+ self.sig_exp = torch.as_tensor(sig_exp).cuda()
40
+
41
+ keys_info = np.load(
42
+ os.path.join(modelpath, "keys_info.npy"), allow_pickle=True
43
+ ).item()
44
+ self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda()
45
+ self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda()
46
+ self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda()
47
+ self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda()
48
+
49
+ def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy):
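+        # The jaw-contour landmarks are view dependent: for each of the 8 left and 8 right contour
+        # groups, the candidate vertex whose projected x is extremal under the current pose is chosen;
+        # the remaining landmarks use the fixed vertex indices in keyinds.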
50
+ id_para = id_para * self.sig_id
51
+ exp_para = exp_para * self.sig_exp
52
+ batch_size = id_para.shape[0]
53
+ num_per_contour = self.left_contours.shape[1]
54
+ left_contours_flat = self.left_contours.reshape(-1)
55
+ right_contours_flat = self.right_contours.reshape(-1)
56
+ sel_index = torch.cat(
57
+ (
58
+ 3 * left_contours_flat.unsqueeze(1),
59
+ 3 * left_contours_flat.unsqueeze(1) + 1,
60
+ 3 * left_contours_flat.unsqueeze(1) + 2,
61
+ ),
62
+ dim=1,
63
+ ).reshape(-1)
64
+ left_geometry = (
65
+ torch.mm(id_para, self.base_id[:, sel_index])
66
+ + torch.mm(exp_para, self.base_exp[:, sel_index])
67
+ + self.mu[sel_index]
68
+ )
69
+ left_geometry = left_geometry.view(batch_size, -1, 3)
70
+ proj_x = forward_transform(
71
+ left_geometry, euler_angle, trans, focal_length, cxy
72
+ )[:, :, 0]
73
+ proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
74
+ arg_min = proj_x.argmin(dim=2)
75
+ left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3)
76
+ left_3dlands = left_geometry[
77
+ torch.arange(batch_size * 8), arg_min.view(-1), :
78
+ ].view(batch_size, 8, 3)
79
+
80
+ sel_index = torch.cat(
81
+ (
82
+ 3 * right_contours_flat.unsqueeze(1),
83
+ 3 * right_contours_flat.unsqueeze(1) + 1,
84
+ 3 * right_contours_flat.unsqueeze(1) + 2,
85
+ ),
86
+ dim=1,
87
+ ).reshape(-1)
88
+ right_geometry = (
89
+ torch.mm(id_para, self.base_id[:, sel_index])
90
+ + torch.mm(exp_para, self.base_exp[:, sel_index])
91
+ + self.mu[sel_index]
92
+ )
93
+ right_geometry = right_geometry.view(batch_size, -1, 3)
94
+ proj_x = forward_transform(
95
+ right_geometry, euler_angle, trans, focal_length, cxy
96
+ )[:, :, 0]
97
+ proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
98
+ arg_max = proj_x.argmax(dim=2)
99
+ right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3)
100
+ right_3dlands = right_geometry[
101
+ torch.arange(batch_size * 8), arg_max.view(-1), :
102
+ ].view(batch_size, 8, 3)
103
+
104
+ sel_index = torch.cat(
105
+ (
106
+ 3 * self.keyinds.unsqueeze(1),
107
+ 3 * self.keyinds.unsqueeze(1) + 1,
108
+ 3 * self.keyinds.unsqueeze(1) + 2,
109
+ ),
110
+ dim=1,
111
+ ).reshape(-1)
112
+ geometry = (
113
+ torch.mm(id_para, self.base_id[:, sel_index])
114
+ + torch.mm(exp_para, self.base_exp[:, sel_index])
115
+ + self.mu[sel_index]
116
+ )
117
+ lands_3d = geometry.view(-1, self.keyinds.shape[0], 3)
118
+ lands_3d[:, :8, :] = left_3dlands
119
+ lands_3d[:, 9:17, :] = right_3dlands
120
+ return lands_3d
121
+
122
+ def forward_geo_sub(self, id_para, exp_para, sub_index):
123
+ id_para = id_para * self.sig_id
124
+ exp_para = exp_para * self.sig_exp
125
+ sel_index = torch.cat(
126
+ (
127
+ 3 * sub_index.unsqueeze(1),
128
+ 3 * sub_index.unsqueeze(1) + 1,
129
+ 3 * sub_index.unsqueeze(1) + 2,
130
+ ),
131
+ dim=1,
132
+ ).reshape(-1)
133
+ geometry = (
134
+ torch.mm(id_para, self.base_id[:, sel_index])
135
+ + torch.mm(exp_para, self.base_exp[:, sel_index])
136
+ + self.mu[sel_index]
137
+ )
138
+ return geometry.reshape(-1, sub_index.shape[0], 3)
139
+
140
+ def forward_geo(self, id_para, exp_para):
141
+ id_para = id_para * self.sig_id
142
+ exp_para = exp_para * self.sig_exp
143
+ geometry = (
144
+ torch.mm(id_para, self.base_id)
145
+ + torch.mm(exp_para, self.base_exp)
146
+ + self.mu
147
+ )
148
+ return geometry.reshape(-1, self.point_num, 3)
149
+
150
+ def forward_tex(self, tex_para):
151
+ tex_para = tex_para * self.sig_tex
152
+ texture = torch.mm(tex_para, self.base_tex) + self.mu_tex
153
+ return texture.reshape(-1, self.point_num, 3)
data_utils/face_tracking/geo_transform.py ADDED
@@ -0,0 +1,69 @@
1
+ """This module contains functions for geometry transform and camera projection"""
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+
7
+ def euler2rot(euler_angle):
8
+ batch_size = euler_angle.shape[0]
9
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
10
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
11
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
12
+ one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
13
+ zero = torch.zeros(
14
+ (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device
15
+ )
16
+ rot_x = torch.cat(
17
+ (
18
+ torch.cat((one, zero, zero), 1),
19
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
20
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
21
+ ),
22
+ 2,
23
+ )
24
+ rot_y = torch.cat(
25
+ (
26
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
27
+ torch.cat((zero, one, zero), 1),
28
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
29
+ ),
30
+ 2,
31
+ )
32
+ rot_z = torch.cat(
33
+ (
34
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
35
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
36
+ torch.cat((zero, zero, one), 1),
37
+ ),
38
+ 2,
39
+ )
40
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
41
+
42
+
43
+ def rot_trans_geo(geometry, rot, trans):
44
+ rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1)
45
+ return rott_geo.permute(0, 2, 1)
46
+
47
+
48
+ def euler_trans_geo(geometry, euler, trans):
49
+ rot = euler2rot(euler)
50
+ return rot_trans_geo(geometry, rot, trans)
51
+
52
+
53
+ def proj_geo(rott_geo, camera_para):
54
+ fx = camera_para[:, 0]
55
+ fy = camera_para[:, 0]
56
+ cx = camera_para[:, 1]
57
+ cy = camera_para[:, 2]
58
+
59
+ X = rott_geo[:, :, 0]
60
+ Y = rott_geo[:, :, 1]
61
+ Z = rott_geo[:, :, 2]
62
+
63
+ fxX = fx[:, None] * X
64
+ fyY = fy[:, None] * Y
65
+
66
+ proj_x = -fxX / Z + cx[:, None]
67
+ proj_y = fyY / Z + cy[:, None]
68
+
69
+ return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
data_utils/face_tracking/render_3dmm.py ADDED
@@ -0,0 +1,202 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import os
5
+ from pytorch3d.structures import Meshes
6
+ from pytorch3d.renderer import (
7
+ look_at_view_transform,
8
+ PerspectiveCameras,
9
+ FoVPerspectiveCameras,
10
+ PointLights,
11
+ DirectionalLights,
12
+ Materials,
13
+ RasterizationSettings,
14
+ MeshRenderer,
15
+ MeshRasterizer,
16
+ SoftPhongShader,
17
+ TexturesUV,
18
+ TexturesVertex,
19
+ blending,
20
+ )
21
+
22
+ from pytorch3d.ops import interpolate_face_attributes
23
+
24
+ from pytorch3d.renderer.blending import (
25
+ BlendParams,
26
+ hard_rgb_blend,
27
+ sigmoid_alpha_blend,
28
+ softmax_rgb_blend,
29
+ )
30
+
31
+
32
+ class SoftSimpleShader(nn.Module):
33
+ """
34
+ Per pixel lighting - the lighting model is applied using the interpolated
35
+ coordinates and normals for each pixel. The blending function returns the
36
+ soft aggregated color using all the faces per pixel.
37
+
38
+ To use the default values, simply initialize the shader with the desired
39
+ device e.g.
40
+
41
+ """
42
+
43
+ def __init__(
44
+ self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
45
+ ):
46
+ super().__init__()
47
+ self.lights = lights if lights is not None else PointLights(device=device)
48
+ self.materials = (
49
+ materials if materials is not None else Materials(device=device)
50
+ )
51
+ self.cameras = cameras
52
+ self.blend_params = blend_params if blend_params is not None else BlendParams()
53
+
54
+ def to(self, device):
55
+ # Manually move to device modules which are not subclasses of nn.Module
56
+ self.cameras = self.cameras.to(device)
57
+ self.materials = self.materials.to(device)
58
+ self.lights = self.lights.to(device)
59
+ return self
60
+
61
+ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
62
+
63
+ texels = meshes.sample_textures(fragments)
64
+ blend_params = kwargs.get("blend_params", self.blend_params)
65
+
66
+ cameras = kwargs.get("cameras", self.cameras)
67
+ if cameras is None:
68
+ msg = "Cameras must be specified either at initialization \
69
+ or in the forward pass of SoftPhongShader"
70
+ raise ValueError(msg)
71
+ znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
72
+ zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
73
+ images = softmax_rgb_blend(
74
+ texels, fragments, blend_params, znear=znear, zfar=zfar
75
+ )
76
+ return images
77
+
78
+
79
+ class Render_3DMM(nn.Module):
80
+ def __init__(
81
+ self,
82
+ focal=1015,
83
+ img_h=500,
84
+ img_w=500,
85
+ batch_size=1,
86
+ device=torch.device("cuda:0"),
87
+ ):
88
+ super(Render_3DMM, self).__init__()
89
+
90
+ self.focal = focal
91
+ self.img_h = img_h
92
+ self.img_w = img_w
93
+ self.device = device
94
+ self.renderer = self.get_render(batch_size)
95
+
96
+ dir_path = os.path.dirname(os.path.realpath(__file__))
97
+ topo_info = np.load(
98
+ os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
99
+ ).item()
100
+ self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
101
+ self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)
102
+
103
+ def compute_normal(self, geometry):
104
+ vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
105
+ vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
106
+ vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
107
+ nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
108
+ tri_normal = nn.functional.normalize(nnorm, dim=2)
109
+ v_norm = tri_normal[:, self.vert_tris, :].sum(2)
110
+ vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
111
+ return vert_normal
112
+
113
+ def get_render(self, batch_size=1):
114
+ half_s = self.img_w * 0.5
115
+ R, T = look_at_view_transform(10, 0, 0)
116
+ R = R.repeat(batch_size, 1, 1)
117
+ T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)
118
+
119
+ cameras = FoVPerspectiveCameras(
120
+ device=self.device,
121
+ R=R,
122
+ T=T,
123
+ znear=0.01,
124
+ zfar=20,
125
+ fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
126
+ )
127
+ lights = PointLights(
128
+ device=self.device,
129
+ location=[[0.0, 0.0, 1e5]],
130
+ ambient_color=[[1, 1, 1]],
131
+ specular_color=[[0.0, 0.0, 0.0]],
132
+ diffuse_color=[[0.0, 0.0, 0.0]],
133
+ )
134
+ sigma = 1e-4
135
+ raster_settings = RasterizationSettings(
136
+ image_size=(self.img_h, self.img_w),
137
+ blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
138
+ faces_per_pixel=2,
139
+ perspective_correct=False,
140
+ )
141
+ blend_params = blending.BlendParams(background_color=[0, 0, 0])
142
+ renderer = MeshRenderer(
143
+ rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
144
+ shader=SoftSimpleShader(
145
+ lights=lights, blend_params=blend_params, cameras=cameras
146
+ ),
147
+ )
148
+ return renderer.to(self.device)
149
+
150
+ @staticmethod
151
+ def Illumination_layer(face_texture, norm, gamma):
152
+
153
+ n_b, num_vertex, _ = face_texture.size()
154
+ n_v_full = n_b * num_vertex
155
+ gamma = gamma.view(-1, 3, 9).clone()
156
+ gamma[:, :, 0] += 0.8
157
+
158
+ gamma = gamma.permute(0, 2, 1)
159
+
160
+ a0 = np.pi
161
+ a1 = 2 * np.pi / np.sqrt(3.0)
162
+ a2 = 2 * np.pi / np.sqrt(8.0)
163
+ c0 = 1 / np.sqrt(4 * np.pi)
164
+ c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
165
+ c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
166
+ d0 = 0.5 / np.sqrt(3.0)
167
+
168
+ Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
169
+ norm = norm.view(-1, 3)
170
+ nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
171
+ arrH = []
172
+
173
+ arrH.append(Y0)
174
+ arrH.append(-a1 * c1 * ny)
175
+ arrH.append(a1 * c1 * nz)
176
+ arrH.append(-a1 * c1 * nx)
177
+ arrH.append(a2 * c2 * nx * ny)
178
+ arrH.append(-a2 * c2 * ny * nz)
179
+ arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
180
+ arrH.append(-a2 * c2 * nx * nz)
181
+ arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))
182
+
183
+ H = torch.stack(arrH, 1)
184
+ Y = H.view(n_b, num_vertex, 9)
185
+ lighting = Y.bmm(gamma)
186
+
187
+ face_color = face_texture * lighting
188
+ return face_color
189
+
190
+ def forward(self, rott_geometry, texture, diffuse_sh):
191
+ face_normal = self.compute_normal(rott_geometry)
192
+ face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
193
+ face_color = TexturesVertex(face_color)
194
+ mesh = Meshes(
195
+ rott_geometry,
196
+ self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
197
+ face_color,
198
+ )
199
+ rendered_img = self.renderer(mesh)
200
+ rendered_img = torch.clamp(rendered_img, 0, 255)
201
+
202
+ return rendered_img
data_utils/face_tracking/render_land.py ADDED
@@ -0,0 +1,192 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import render_util
4
+ import geo_transform
5
+ import numpy as np
6
+
7
+
8
+ def compute_tri_normal(geometry, tris):
9
+ geometry = geometry.permute(0, 2, 1)
10
+ tri_1 = tris[:, 0]
11
+ tri_2 = tris[:, 1]
12
+ tri_3 = tris[:, 2]
13
+
14
+ vert_1 = torch.index_select(geometry, 2, tri_1)
15
+ vert_2 = torch.index_select(geometry, 2, tri_2)
16
+ vert_3 = torch.index_select(geometry, 2, tri_3)
17
+
18
+ nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1)
19
+ normal = nn.functional.normalize(nnorm).permute(0, 2, 1)
20
+ return normal
21
+
22
+
23
+ class Compute_normal_base(torch.autograd.Function):
24
+ @staticmethod
25
+ def forward(ctx, normal):
26
+ (normal_b,) = render_util.normal_base_forward(normal)
27
+ ctx.save_for_backward(normal)
28
+ return normal_b
29
+
30
+ @staticmethod
31
+ def backward(ctx, grad_normal_b):
32
+ (normal,) = ctx.saved_tensors
33
+ (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
34
+ return grad_normal
35
+
36
+
37
+ class Normal_Base(torch.nn.Module):
38
+ def __init__(self):
39
+ super(Normal_Base, self).__init__()
40
+
41
+ def forward(self, normal):
42
+ return Compute_normal_base.apply(normal)
43
+
44
+
45
+ def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
46
+ point_num = geometry.shape[1]
47
+ rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
48
+ proj_geo = geo_transform.proj_geo(rott_geo, cam)
49
+ rot_tri_normal = compute_tri_normal(rott_geo, tris)
50
+ rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)
51
+ is_visible = -torch.bmm(
52
+ rot_vert_normal.reshape(-1, 1, 3),
53
+ nn.functional.normalize(rott_geo.reshape(-1, 3, 1)),
54
+ ).reshape(-1, point_num)
55
+ is_visible[is_visible < 0.01] = -1
56
+ pixel_valid = torch.zeros(
57
+ (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
58
+ dtype=torch.float32,
59
+ device=ori_img.device,
60
+ )
61
+ return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid
62
+
63
+
64
+ class Render_Face(torch.autograd.Function):
65
+ @staticmethod
66
+ def forward(
67
+ ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
68
+ ):
69
+ batch_size, h, w, _ = ori_img.shape
70
+ ori_img = ori_img.view(batch_size, -1, 3)
71
+ ori_size = torch.cat(
72
+ (
73
+ torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
74
+ * h,
75
+ torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
76
+ * w,
77
+ ),
78
+ dim=1,
79
+ ).view(-1)
80
+ tri_index, tri_coord, render, real = render_util.render_face_forward(
81
+ proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
82
+ )
83
+ ctx.save_for_backward(
84
+ ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
85
+ )
86
+ return render, real
87
+
88
+ @staticmethod
89
+ def backward(ctx, grad_render, grad_real):
90
+ (
91
+ ori_img,
92
+ ori_size,
93
+ proj_geo,
94
+ texture,
95
+ nbl,
96
+ tri_inds,
97
+ tri_index,
98
+ tri_coord,
99
+ ) = ctx.saved_tensors
100
+ grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
101
+ grad_render,
102
+ grad_real,
103
+ ori_img,
104
+ ori_size,
105
+ proj_geo,
106
+ texture,
107
+ nbl,
108
+ tri_inds,
109
+ tri_index,
110
+ tri_coord,
111
+ )
112
+ return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
113
+
114
+
115
+ class Render_RGB(nn.Module):
116
+ def __init__(self):
117
+ super(Render_RGB, self).__init__()
118
+
119
+ def forward(
120
+ self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
121
+ ):
122
+ return Render_Face.apply(
123
+ proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
124
+ )
125
+
126
+
127
+ def cal_land(proj_geo, is_visible, lands_info, land_num):
128
+ (land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
129
+ proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[
130
+ :, :2
131
+ ].reshape(-1, land_num, 2)
132
+ return proj_land
133
+
134
+
135
+ class Render_Land(nn.Module):
136
+ def __init__(self):
137
+ super(Render_Land, self).__init__()
138
+ lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32)
139
+ self.lands_info = torch.as_tensor(lands_info).cuda()
140
+ tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64)
141
+ self.tris = torch.as_tensor(tris).cuda() - 1
142
+ vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64)
143
+ self.vert_tris = torch.as_tensor(vert_tris).cuda()
144
+ self.normal_baser = Normal_Base().cuda()
145
+ self.renderer = Render_RGB().cuda()
146
+
147
+ def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
148
+ batch_size, h, w, _ = ori_img.shape
149
+ ori_img = ori_img.view(batch_size, -1, 3)
150
+ ori_size = torch.cat(
151
+ (
152
+ torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
153
+ * h,
154
+ torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
155
+ * w,
156
+ ),
157
+ dim=1,
158
+ ).view(-1)
159
+ rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
160
+ geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
161
+ )
162
+ tri_nb = self.normal_baser(rot_tri_normal.contiguous())
163
+ nbl = torch.bmm(
164
+ tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
165
+ )
166
+ texture = torch.ones_like(geometry) * 200
167
+ (render,) = render_util.render_mesh(
168
+ proj_geo, ori_img, ori_size, texture, nbl, self.tris
169
+ )
170
+ return render.view(batch_size, h, w, 3).byte()
171
+
172
+ def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
173
+ rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
174
+ geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
175
+ )
176
+ tri_nb = self.normal_baser(rot_tri_normal.contiguous())
177
+ nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
178
+ render, real = self.renderer(
179
+ proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
180
+ )
181
+ proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])
182
+ col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
183
+ ori_img.shape[0], -1
184
+ )
185
+ col_dis = torch.mean(col_minus * pixel_valid) / (
186
+ torch.mean(pixel_valid) + 0.00001
187
+ )
188
+ land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
189
+ ori_img.shape[0], -1
190
+ )
191
+ lan_dis = torch.mean(land_dists)
192
+ return col_dis, lan_dis
data_utils/face_tracking/util.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def compute_tri_normal(geometry, tris):
7
+ tri_1 = tris[:, 0]
8
+ tri_2 = tris[:, 1]
9
+ tri_3 = tris[:, 2]
10
+ vert_1 = torch.index_select(geometry, 1, tri_1)
11
+ vert_2 = torch.index_select(geometry, 1, tri_2)
12
+ vert_3 = torch.index_select(geometry, 1, tri_3)
13
+ nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
14
+ normal = nn.functional.normalize(nnorm)
15
+ return normal
16
+
17
+
18
+ def euler2rot(euler_angle):
19
+ batch_size = euler_angle.shape[0]
20
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
21
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
22
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
23
+ one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
24
+ zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
25
+ rot_x = torch.cat(
26
+ (
27
+ torch.cat((one, zero, zero), 1),
28
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
29
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
30
+ ),
31
+ 2,
32
+ )
33
+ rot_y = torch.cat(
34
+ (
35
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
36
+ torch.cat((zero, one, zero), 1),
37
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
38
+ ),
39
+ 2,
40
+ )
41
+ rot_z = torch.cat(
42
+ (
43
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
44
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
45
+ torch.cat((zero, zero, one), 1),
46
+ ),
47
+ 2,
48
+ )
49
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
50
+
51
+
52
+ def rot_trans_pts(geometry, rot, trans):
53
+ rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None]
54
+ return rott_geo.permute(0, 2, 1)
55
+
56
+
57
+ def cal_lap_loss(tensor_list, weight_list):
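+    # Temporal smoothness term: convolve a 1-D Laplacian kernel (-0.5, 1, -0.5) along the last axis
+    # of each tensor and average the squared response, penalizing second-order (frame-to-frame) jumps.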
58
+ lap_kernel = (
59
+ torch.Tensor((-0.5, 1.0, -0.5))
60
+ .unsqueeze(0)
61
+ .unsqueeze(0)
62
+ .float()
63
+ .to(tensor_list[0].device)
64
+ )
65
+ loss_lap = 0
66
+ for i in range(len(tensor_list)):
67
+ in_tensor = tensor_list[i]
68
+ in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1])
69
+ out_tensor = F.conv1d(in_tensor, lap_kernel)
70
+ loss_lap += torch.mean(out_tensor ** 2) * weight_list[i]
71
+ return loss_lap
72
+
73
+
74
+ def proj_pts(rott_geo, focal_length, cxy):
75
+ cx, cy = cxy[0], cxy[1]
76
+ X = rott_geo[:, :, 0]
77
+ Y = rott_geo[:, :, 1]
78
+ Z = rott_geo[:, :, 2]
79
+ fxX = focal_length * X
80
+ fyY = focal_length * Y
81
+ proj_x = -fxX / Z + cx
82
+ proj_y = fyY / Z + cy
83
+ return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
84
+
85
+
86
+ def forward_rott(geometry, euler_angle, trans):
87
+ rot = euler2rot(euler_angle)
88
+ rott_geo = rot_trans_pts(geometry, rot, trans)
89
+ return rott_geo
90
+
91
+
92
+ def forward_transform(geometry, euler_angle, trans, focal_length, cxy):
93
+ rot = euler2rot(euler_angle)
94
+ rott_geo = rot_trans_pts(geometry, rot, trans)
95
+ proj_geo = proj_pts(rott_geo, focal_length, cxy)
96
+ return proj_geo
97
+
98
+
99
+ def cal_lan_loss(proj_lan, gt_lan):
100
+ return torch.mean((proj_lan - gt_lan) ** 2)
101
+
102
+
103
+ def cal_col_loss(pred_img, gt_img, img_mask):
104
+ pred_img = pred_img.float()
105
+ # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255
106
+ loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
107
+ loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2))
108
+ loss = torch.mean(loss)
109
+ return loss
data_utils/hubert.py ADDED
@@ -0,0 +1,92 @@
1
+ from transformers import Wav2Vec2Processor, HubertModel
2
+ import soundfile as sf
3
+ import numpy as np
4
+ import torch
5
+
6
+ print("Loading the Wav2Vec2 Processor...")
7
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
8
+ print("Loading the HuBERT Model...")
9
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
10
+
11
+
12
+ def get_hubert_from_16k_wav(wav_16k_name):
13
+ speech_16k, _ = sf.read(wav_16k_name)
14
+ hubert = get_hubert_from_16k_speech(speech_16k)
15
+ return hubert
16
+
17
+ @torch.no_grad()
18
+ def get_hubert_from_16k_speech(speech, device="cuda:0"):
19
+ global hubert_model
20
+ hubert_model = hubert_model.to(device)
21
+     if speech.ndim == 2:
22
+ speech = speech[:, 0] # [T, 2] ==> [T,]
23
+ input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
24
+ input_values_all = input_values_all.to(device)
25
+ # For long audio sequence, due to the memory limitation, we cannot process them in one run
26
+ # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320
27
+ # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
28
+ # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320
29
+ # We have the equation to calculate out time step: T = floor((t-k)/s)
30
+ # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip
31
+ # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N
32
+ kernel = 400
33
+ stride = 320
34
+ clip_length = stride * 1000
35
+ num_iter = input_values_all.shape[1] // clip_length
36
+ expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
37
+ res_lst = []
38
+ for i in range(num_iter):
39
+ if i == 0:
40
+ start_idx = 0
41
+ end_idx = clip_length - stride + kernel
42
+ else:
43
+ start_idx = clip_length * i
44
+ end_idx = start_idx + (clip_length - stride + kernel)
45
+ input_values = input_values_all[:, start_idx: end_idx]
46
+ hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
47
+ res_lst.append(hidden_states[0])
48
+ if num_iter > 0:
49
+ input_values = input_values_all[:, clip_length * num_iter:]
50
+ else:
51
+ input_values = input_values_all
52
+ # if input_values.shape[1] != 0:
53
+ if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
54
+ hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
55
+ res_lst.append(hidden_states[0])
56
+ ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
57
+ # assert ret.shape[0] == expected_T
58
+ assert abs(ret.shape[0] - expected_T) <= 1
59
+ if ret.shape[0] < expected_T:
60
+ ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))
61
+ else:
62
+ ret = ret[:expected_T]
63
+ return ret
64
+
65
+ def make_even_first_dim(tensor):
66
+ size = list(tensor.size())
67
+ if size[0] % 2 == 1:
68
+ size[0] -= 1
69
+ return tensor[:size[0]]
70
+ return tensor
71
+
72
+ import soundfile as sf
73
+ import numpy as np
74
+ import torch
75
+ from argparse import ArgumentParser
76
+ import librosa
77
+
78
+ parser = ArgumentParser()
79
+ parser.add_argument('--wav', type=str, help='')
80
+ args = parser.parse_args()
81
+
82
+ wav_name = args.wav
83
+
84
+ speech, sr = sf.read(wav_name)
85
+ speech_16k = librosa.resample(speech, orig_sr=sr, target_sr=16000)
86
+ print("SR: {} to {}".format(sr, 16000))
87
+ # print(speech.shape, speech_16k.shape)
88
+
89
+ hubert_hidden = get_hubert_from_16k_speech(speech_16k)
90
+ hubert_hidden = make_even_first_dim(hubert_hidden).reshape(-1, 2, 1024)
91
+ np.save(wav_name.replace('.wav', '_hu.npy'), hubert_hidden.detach().numpy())
92
+ print(hubert_hidden.detach().numpy().shape)
data_utils/process.py ADDED
@@ -0,0 +1,405 @@
1
+ import os
2
+ import glob
3
+ import tqdm
4
+ import json
5
+ import argparse
6
+ import cv2
7
+ import numpy as np
8
+
9
+ def extract_audio(path, out_path, sample_rate=16000):
10
+
11
+ print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
12
+ cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}'
13
+ os.system(cmd)
14
+ print(f'[INFO] ===== extracted audio =====')
15
+
16
+
17
+ def extract_audio_features(path, mode='wav2vec'):
18
+
19
+ print(f'[INFO] ===== extract audio labels for {path} =====')
20
+ if mode == 'wav2vec':
21
+ cmd = f'python nerf/asr.py --wav {path} --save_feats'
22
+ else: # deepspeech
23
+ cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}'
24
+ os.system(cmd)
25
+ print(f'[INFO] ===== extracted audio labels =====')
26
+
27
+
28
+
29
+ def extract_images(path, out_path, fps=25):
30
+
31
+ print(f'[INFO] ===== extract images from {path} to {out_path} =====')
32
+ cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}'
33
+ os.system(cmd)
34
+ print(f'[INFO] ===== extracted images =====')
35
+
36
+
37
+ def extract_semantics(ori_imgs_dir, parsing_dir):
38
+
39
+ print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
40
+ cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}'
41
+ os.system(cmd)
42
+ print(f'[INFO] ===== extracted semantics =====')
43
+
44
+
45
+ def extract_landmarks(ori_imgs_dir):
46
+
47
+ print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
48
+
49
+ import face_alignment
50
+ try:
51
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
52
+ except:
53
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
54
+ image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
55
+ for image_path in tqdm.tqdm(image_paths):
56
+ input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
57
+ input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
58
+ preds = fa.get_landmarks(input)
59
+ if len(preds) > 0:
60
+ lands = preds[0].reshape(-1, 2)[:,:2]
61
+ np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f')
62
+ del fa
63
+ print(f'[INFO] ===== extracted face landmarks =====')
64
+
65
+
66
+ def extract_background(base_dir, ori_imgs_dir):
67
+
68
+ print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')
69
+
70
+ from sklearn.neighbors import NearestNeighbors
71
+
72
+ image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
73
+ # only use 1/20 image_paths
74
+ image_paths = image_paths[::20]
75
+ # read one image to get H/W
76
+ tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
77
+ h, w = tmp_image.shape[:2]
78
+
79
+ # nearest neighbors
80
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
81
+ distss = []
82
+ for image_path in tqdm.tqdm(image_paths):
83
+ parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
84
+ bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
85
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
86
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
87
+ dists, _ = nbrs.kneighbors(all_xys)
88
+ distss.append(dists)
89
+
90
+ distss = np.stack(distss)
91
+ max_dist = np.max(distss, 0)
92
+ max_id = np.argmax(distss, 0)
93
+
94
+ bc_pixs = max_dist > 5
95
+ bc_pixs_id = np.nonzero(bc_pixs)
96
+ bc_ids = max_id[bc_pixs]
97
+
98
+ imgs = []
99
+ num_pixs = distss.shape[1]
100
+ for image_path in image_paths:
101
+ img = cv2.imread(image_path)
102
+ imgs.append(img)
103
+ imgs = np.stack(imgs).reshape(-1, num_pixs, 3)
104
+
105
+ bc_img = np.zeros((h*w, 3), dtype=np.uint8)
106
+ bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
107
+ bc_img = bc_img.reshape(h, w, 3)
108
+
109
+ max_dist = max_dist.reshape(h, w)
110
+ bc_pixs = max_dist > 5
111
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
112
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
113
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
114
+ distances, indices = nbrs.kneighbors(bg_xys)
115
+ bg_fg_xys = fg_xys[indices[:, 0]]
116
+ bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
117
+
118
+ cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)
119
+
120
+ print(f'[INFO] ===== extracted background image =====')
121
+
122
+
123
+ def extract_torso_and_gt(base_dir, ori_imgs_dir):
124
+
125
+ print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')
126
+
127
+ from scipy.ndimage import binary_erosion, binary_dilation
128
+
129
+ # load bg
130
+ bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED)
131
+
132
+ image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
133
+
134
+ for image_path in tqdm.tqdm(image_paths):
135
+ # read ori image
136
+ ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
137
+
138
+ # read semantics
139
+ seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
140
+ head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
141
+ neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
142
+ torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
143
+ bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)
144
+
145
+ # get gt image
146
+ gt_image = ori_image.copy()
147
+ gt_image[bg_part] = bg_image[bg_part]
148
+ cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
149
+
150
+ # get torso image
151
+ torso_image = gt_image.copy() # rgb
152
+ torso_image[head_part] = bg_image[head_part]
153
+ torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
154
+
155
+ # torso part "vertical" in-painting...
156
+ L = 8 + 1
157
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
158
+ # lexsort: sort 2D coords first by y then by x,
159
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
160
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
161
+ torso_coords = torso_coords[inds]
162
+ # choose the top pixel for each column
163
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
164
+ top_torso_coords = torso_coords[uid] # [m, 2]
165
+ # only keep top-is-head pixels
166
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0])
167
+ mask = head_part[tuple(top_torso_coords_up.T)]
168
+ if mask.any():
169
+ top_torso_coords = top_torso_coords[mask]
170
+ # get the color
171
+ top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
172
+ # construct inpaint coords (vertically up, or minus in x)
173
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
174
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
175
+ inpaint_torso_coords += inpaint_offsets
176
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
177
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
178
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
179
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
180
+ # set color
181
+ torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
182
+
183
+ inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
184
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
185
+ else:
186
+ inpaint_torso_mask = None
187
+
188
+
189
+ # neck part "vertical" in-painting...
190
+ push_down = 4
191
+ L = 48 + push_down + 1
192
+
193
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
194
+
195
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
196
+ # lexsort: sort 2D coords first by y then by x,
197
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
198
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
199
+ neck_coords = neck_coords[inds]
200
+ # choose the top pixel for each column
201
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
202
+ top_neck_coords = neck_coords[uid] # [m, 2]
203
+ # only keep top-is-head pixels
204
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
205
+ mask = head_part[tuple(top_neck_coords_up.T)]
206
+
207
+ top_neck_coords = top_neck_coords[mask]
208
+ # push these top down for 4 pixels to make the neck inpainting more natural...
209
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
210
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
211
+ # get the color
212
+ top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
213
+ # construct inpaint coords (vertically up, or minus in x)
214
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
215
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
216
+ inpaint_neck_coords += inpaint_offsets
217
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
218
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
219
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
220
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
221
+ # set color
222
+ torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
223
+
224
+ # apply blurring to the inpaint area to avoid vertical-line artifects...
225
+ inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
226
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
227
+
228
+ blur_img = torso_image.copy()
229
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
230
+
231
+ torso_image[inpaint_mask] = blur_img[inpaint_mask]
232
+
233
+ # set mask
234
+ mask = (neck_part | torso_part | inpaint_mask)
235
+ if inpaint_torso_mask is not None:
236
+ mask = mask | inpaint_torso_mask
237
+ torso_image[~mask] = 0
238
+ torso_alpha[~mask] = 0
239
+
240
+ cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1))
241
+
242
+ print(f'[INFO] ===== extracted torso and gt images =====')
243
+
244
+
245
+ def face_tracking(ori_imgs_dir):
246
+
247
+ print(f'[INFO] ===== perform face tracking =====')
248
+
249
+ image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
250
+
251
+ # read one image to get H/W
252
+ tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
253
+ h, w = tmp_image.shape[:2]
254
+
255
+ cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}'
256
+
257
+ os.system(cmd)
258
+
259
+ print(f'[INFO] ===== finished face tracking =====')
260
+
261
+
262
+ def save_transforms(base_dir, ori_imgs_dir):
263
+ print(f'[INFO] ===== save transforms =====')
264
+
265
+ import torch
266
+
267
+ image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
268
+
269
+ # read one image to get H/W
270
+ tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
271
+ h, w = tmp_image.shape[:2]
272
+
273
+ params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
274
+ focal_len = params_dict['focal']
275
+ euler_angle = params_dict['euler']
276
+ trans = params_dict['trans'] / 10.0
277
+ valid_num = euler_angle.shape[0]
278
+
279
+ def euler2rot(euler_angle):
280
+ batch_size = euler_angle.shape[0]
281
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
282
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
283
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
284
+ one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
285
+ zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
286
+ rot_x = torch.cat((
287
+ torch.cat((one, zero, zero), 1),
288
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
289
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
290
+ ), 2)
291
+ rot_y = torch.cat((
292
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
293
+ torch.cat((zero, one, zero), 1),
294
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
295
+ ), 2)
296
+ rot_z = torch.cat((
297
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
298
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
299
+ torch.cat((zero, zero, one), 1)
300
+ ), 2)
301
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
302
+
303
+
304
+ # train_val_split = int(valid_num*0.5)
305
+ # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set.
306
+ train_val_split = int(valid_num * 10 / 11)
307
+
308
+ train_ids = torch.arange(0, train_val_split)
309
+ val_ids = torch.arange(train_val_split, valid_num)
310
+
311
+ rot = euler2rot(euler_angle)
312
+ rot_inv = rot.permute(0, 2, 1)
313
+ trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))
314
+
315
+ pose = torch.eye(4, dtype=torch.float32)
316
+ save_ids = ['train', 'val']
317
+ train_val_ids = [train_ids, val_ids]
318
+ mean_z = -float(torch.mean(trans[:, 2]).item())
319
+
320
+ for split in range(2):
321
+ transform_dict = dict()
322
+ transform_dict['focal_len'] = float(focal_len[0])
323
+ transform_dict['cx'] = float(w/2.0)
324
+ transform_dict['cy'] = float(h/2.0)
325
+ transform_dict['frames'] = []
326
+ ids = train_val_ids[split]
327
+ save_id = save_ids[split]
328
+
329
+ for i in ids:
330
+ i = i.item()
331
+ frame_dict = dict()
332
+ frame_dict['img_id'] = i
333
+ frame_dict['aud_id'] = i
334
+
335
+ pose[:3, :3] = rot_inv[i]
336
+ pose[:3, 3] = trans_inv[i, :, 0]
337
+
338
+ frame_dict['transform_matrix'] = pose.numpy().tolist()
339
+
340
+ transform_dict['frames'].append(frame_dict)
341
+
342
+ with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
343
+ json.dump(transform_dict, fp, indent=2, separators=(',', ': '))
344
+
345
+ print(f'[INFO] ===== finished saving transforms =====')
346
+
347
+
348
+ if __name__ == '__main__':
349
+ parser = argparse.ArgumentParser()
350
+ parser.add_argument('path', type=str, help="path to video file")
351
+ parser.add_argument('--task', type=int, default=-1, help="-1 means all")
352
+ parser.add_argument('--asr', type=str, default='deepspeech', help="wav2vec or deepspeech")
353
+
354
+ opt = parser.parse_args()
355
+
356
+ base_dir = os.path.dirname(opt.path)
357
+
358
+ wav_path = os.path.join(base_dir, 'aud.wav')
359
+ ori_imgs_dir = os.path.join(base_dir, 'ori_imgs')
360
+ parsing_dir = os.path.join(base_dir, 'parsing')
361
+ gt_imgs_dir = os.path.join(base_dir, 'gt_imgs')
362
+ torso_imgs_dir = os.path.join(base_dir, 'torso_imgs')
363
+
364
+ os.makedirs(ori_imgs_dir, exist_ok=True)
365
+ os.makedirs(parsing_dir, exist_ok=True)
366
+ os.makedirs(gt_imgs_dir, exist_ok=True)
367
+ os.makedirs(torso_imgs_dir, exist_ok=True)
368
+
369
+
370
+ # extract audio
371
+ if opt.task == -1 or opt.task == 1:
372
+ extract_audio(opt.path, wav_path)
373
+
374
+ # extract audio features
375
+ if opt.task == -1 or opt.task == 2:
376
+ extract_audio_features(wav_path, mode=opt.asr)
377
+
378
+ # extract images
379
+ if opt.task == -1 or opt.task == 3:
380
+ extract_images(opt.path, ori_imgs_dir)
381
+
382
+ # face parsing
383
+ if opt.task == -1 or opt.task == 4:
384
+ extract_semantics(ori_imgs_dir, parsing_dir)
385
+
386
+ # extract bg
387
+ if opt.task == -1 or opt.task == 5:
388
+ extract_background(base_dir, ori_imgs_dir)
389
+
390
+ # extract torso images and gt_images
391
+ if opt.task == -1 or opt.task == 6:
392
+ extract_torso_and_gt(base_dir, ori_imgs_dir)
393
+
394
+ # extract face landmarks
395
+ if opt.task == -1 or opt.task == 7:
396
+ extract_landmarks(ori_imgs_dir)
397
+
398
+ # face tracking
399
+ if opt.task == -1 or opt.task == 8:
400
+ face_tracking(ori_imgs_dir)
401
+
402
+ # save transforms.json
403
+ if opt.task == -1 or opt.task == 9:
404
+ save_transforms(base_dir, ori_imgs_dir)
405
+
data_utils/wav2mel.py ADDED
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from wav2mel_hparams import hparams as hp
6
+ from librosa.core.audio import resample
7
+ import soundfile as sf
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)
11
+
12
+ def preemphasis(wav, k, preemphasize=True):
13
+ if preemphasize:
14
+ return signal.lfilter([1, -k], [1], wav)
15
+ return wav
16
+
17
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
18
+ if inv_preemphasize:
19
+ return signal.lfilter([1], [1, -k], wav)
20
+ return wav
21
+
22
+ def get_hop_size():
23
+ hop_size = hp.hop_size
24
+ if hop_size is None:
25
+ assert hp.frame_shift_ms is not None
26
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
27
+ return hop_size
28
+
29
+ def linearspectrogram(wav):
30
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
31
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
32
+
33
+ if hp.signal_normalization:
34
+ return _normalize(S)
35
+ return S
36
+
37
+ def melspectrogram(wav):
38
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
39
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
40
+
41
+ if hp.signal_normalization:
42
+ return _normalize(S)
43
+ return S
44
+
45
+ def _stft(y):
46
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
47
+
48
+ ##########################################################
49
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
50
+ def num_frames(length, fsize, fshift):
51
+ """Compute number of time frames of spectrogram
52
+ """
53
+ pad = (fsize - fshift)
54
+ if length % fshift == 0:
55
+ M = (length + pad * 2 - fsize) // fshift + 1
56
+ else:
57
+ M = (length + pad * 2 - fsize) // fshift + 2
58
+ return M
59
+
60
+
61
+ def pad_lr(x, fsize, fshift):
62
+ """Compute left and right padding
63
+ """
64
+ M = num_frames(len(x), fsize, fshift)
65
+ pad = (fsize - fshift)
66
+ T = len(x) + 2 * pad
67
+ r = (M - 1) * fshift + fsize - T
68
+ return pad, pad + r
69
+ ##########################################################
70
+ #Librosa correct padding
71
+ def librosa_pad_lr(x, fsize, fshift):
72
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
73
+
74
+ # Conversions
75
+ _mel_basis = None
76
+
77
+ def _linear_to_mel(spectogram):
78
+ global _mel_basis
79
+ if _mel_basis is None:
80
+ _mel_basis = _build_mel_basis()
81
+ return np.dot(_mel_basis, spectogram)
82
+
83
+ def _build_mel_basis():
84
+ assert hp.fmax <= hp.sample_rate // 2
85
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
86
+ fmin=hp.fmin, fmax=hp.fmax)
87
+
88
+ def _amp_to_db(x):
89
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
90
+ return 20 * np.log10(np.maximum(min_level, x))
91
+
92
+ def _db_to_amp(x):
93
+ return np.power(10.0, (x) * 0.05)
94
+
95
+ def _normalize(S):
96
+ if hp.allow_clipping_in_normalization:
97
+ if hp.symmetric_mels:
98
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
99
+ -hp.max_abs_value, hp.max_abs_value)
100
+ else:
101
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
102
+
103
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
104
+ if hp.symmetric_mels:
105
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
106
+ else:
107
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
108
+
109
+ def _denormalize(D):
110
+ if hp.allow_clipping_in_normalization:
111
+ if hp.symmetric_mels:
112
+ return (((np.clip(D, -hp.max_abs_value,
113
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
114
+ + hp.min_level_db)
115
+ else:
116
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
117
+
118
+ if hp.symmetric_mels:
119
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
120
+ else:
121
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
122
+
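_normalize maps dB values from [min_level_db, 0] onto [-max_abs_value, max_abs_value] in the symmetric case, and _denormalize inverts it. A quick numeric round-trip with the default hparams (min_level_db=-100, max_abs_value=4), as a sanity check; the same import caveats as the sketch above apply:

```
# Round-trip check of the symmetric normalization with the default hparams.
import numpy as np
from wav2mel import _normalize, _denormalize

S = np.array([-100.0, -50.0, 0.0])   # dB values spanning the allowed range
N = _normalize(S)                    # -> [-4., 0., 4.] with max_abs_value = 4
print(N, _denormalize(N))            # denormalize recovers [-100., -50., 0.]
```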
123
+
124
+
125
+ def wav2mel(wav, sr):
126
+ wav16k = resample(wav, orig_sr=sr, target_sr=16000)
127
+ # print('wav16k', wav16k.shape, wav16k.dtype)
128
+ mel = melspectrogram(wav16k)
129
+ # print('mel', mel.shape, mel.dtype)
130
+ if np.isnan(mel.reshape(-1)).sum() > 0:
131
+ raise ValueError(
132
+ 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
133
+ # mel.dtype = np.float32
134
+ mel_chunks = []
135
+ mel_idx_multiplier = 80. / 25
136
+ mel_step_size = 8
137
+ i = start_idx = 0
138
+ while start_idx < len(mel[0]):
139
+ start_idx = int(i * mel_idx_multiplier)
140
+ if start_idx + mel_step_size // 2 > len(mel[0]):
141
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
142
+ elif start_idx - mel_step_size // 2 < 0:
143
+ mel_chunks.append(mel[:, :mel_step_size])
144
+ else:
145
+ mel_chunks.append(mel[:, start_idx - mel_step_size // 2 : start_idx + mel_step_size // 2])
146
+ i += 1
147
+ return mel_chunks
148
+
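The hard-coded constants in wav2mel come from the hparams below: a 200-sample hop at 16 kHz gives 80 mel frames per second, video runs at 25 fps, and each chunk is an 8-frame window centred on the corresponding video frame. Spelled out:

```
# The chunking arithmetic behind mel_idx_multiplier and mel_step_size above.
mel_frames_per_sec = 16000 / 200               # hop_size 200 at 16 kHz -> 80 mel frames / s
mel_idx_multiplier = mel_frames_per_sec / 25   # 3.2 mel frames per 25 fps video frame
window_ms = 8 * 200 / 16000 * 1000             # mel_step_size 8 -> a 100 ms window per chunk
print(mel_idx_multiplier, window_ms)           # 3.2 100.0
```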
149
+
150
+
151
+ if __name__ == '__main__':
152
+ import argparse
153
+
154
+ parser = argparse.ArgumentParser()
155
+ parser.add_argument('--wav', type=str, default='')
156
+ parser.add_argument('--save_feats', action='store_true')
157
+
158
+ opt = parser.parse_args()
159
+
160
+ wav, sr = librosa.core.load(opt.wav)
161
+ mel_chunks = np.array(wav2mel(wav.T, sr))
162
+ print(mel_chunks.shape, mel_chunks.transpose(0,2,1).shape)
163
+
164
+ if opt.save_feats:
165
+ save_path = opt.wav.replace('.wav', '_mel.npy')
166
+ np.save(save_path, mel_chunks.transpose(0,2,1))
167
+ print(f"[INFO] saved logits to {save_path}")
data_utils/wav2mel_hparams.py ADDED
@@ -0,0 +1,80 @@
1
+ class HParams:
2
+ def __init__(self, **kwargs):
3
+ self.data = {}
4
+
5
+ for key, value in kwargs.items():
6
+ self.data[key] = value
7
+
8
+ def __getattr__(self, key):
9
+ if key not in self.data:
10
+ raise AttributeError("'HParams' object has no attribute %s" % key)
11
+ return self.data[key]
12
+
13
+ def set_hparam(self, key, value):
14
+ self.data[key] = value
15
+
16
+ # Default hyperparameters
17
+ hparams = HParams(
18
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
19
+ # network
20
+ rescale=True, # Whether to rescale audio prior to preprocessing
21
+ rescaling_max=0.9, # Rescaling value
22
+
23
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
24
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
25
+ # Does not work if n_fft is not a multiple of hop_size!
26
+ use_lws=False,
27
+
28
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
29
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
30
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
31
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
32
+
33
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
34
+
35
+ # Mel and Linear spectrograms normalization/scaling and clipping
36
+ signal_normalization=True,
37
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
38
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
39
+ symmetric_mels=True,
40
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
41
+ # faster and cleaner convergence)
42
+ max_abs_value=4.,
43
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
44
+ # be too big to avoid gradient explosion,
45
+ # not too small for fast convergence)
46
+ # Contribution by @begeekmyfriend
47
+ # Spectrogram pre-emphasis (lfilter): reduces spectrogram noise and helps the model's
48
+ # confidence levels. Also allows for better Griffin-Lim phase reconstruction.
49
+ preemphasize=True, # whether to apply filter
50
+ preemphasis=0.97, # filter coefficient.
51
+
52
+ # Limits
53
+ min_level_db=-100,
54
+ ref_level_db=20,
55
+ fmin=65,
56
+ # Set this to 55 if your speaker is male; if female, 95 should help remove noise. (To be
57
+ # tuned per dataset. Pitch info: male ~[65, 260], female ~[100, 525])
58
+ fmax=6000, # To be increased/reduced depending on data.
59
+
60
+ ###################### Our training parameters #################################
61
+ img_size=96,
62
+ fps=25,
63
+
64
+ batch_size=16,
65
+ initial_learning_rate=1e-4,
66
+ nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
67
+ num_workers=16,
68
+ checkpoint_interval=3000,
69
+ eval_interval=3000,
70
+ save_optimizer_state=True,
71
+
72
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
73
+ syncnet_batch_size=64,
74
+ syncnet_lr=1e-4,
75
+ syncnet_eval_interval=10000,
76
+ syncnet_checkpoint_interval=10000,
77
+
78
+ disc_wt=0.07,
79
+ disc_initial_learning_rate=1e-4,
80
+ )
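A short usage sketch of the HParams container above; hop_size may also be left as None and derived from frame_shift_ms, which is what get_hop_size() in wav2mel.py does:

```
# Accessing hyperparameters and deriving hop_size from a frame shift in milliseconds.
from wav2mel_hparams import hparams, HParams

print(hparams.num_mels, hparams.hop_size)   # 80 200 (attribute access via __getattr__)

alt = HParams(sample_rate=16000, frame_shift_ms=12.5, hop_size=None)
print(int(alt.frame_shift_ms / 1000 * alt.sample_rate))   # 200, the same hop size
```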
data_utils/wav2vec.py ADDED
@@ -0,0 +1,420 @@
1
+ import time
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoModelForCTC, AutoProcessor
6
+
7
+ import pyaudio
8
+ import soundfile as sf
9
+ import resampy
10
+
11
+ from queue import Queue
12
+ from threading import Thread, Event
13
+
14
+
15
+ def _read_frame(stream, exit_event, queue, chunk):
16
+
17
+ while True:
18
+ if exit_event.is_set():
19
+ print(f'[INFO] read frame thread ends')
20
+ break
21
+ frame = stream.read(chunk, exception_on_overflow=False)
22
+ frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
23
+ queue.put(frame)
24
+
25
+ def _play_frame(stream, exit_event, queue, chunk):
26
+
27
+ while True:
28
+ if exit_event.is_set():
29
+ print(f'[INFO] play frame thread ends')
30
+ break
31
+ frame = queue.get()
32
+ frame = (frame * 32767).astype(np.int16).tobytes()
33
+ stream.write(frame, chunk)
34
+
35
+ class ASR:
36
+ def __init__(self, opt):
37
+
38
+ self.opt = opt
39
+
40
+ self.play = opt.asr_play
41
+
42
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
43
+ self.fps = opt.fps # 20 ms per frame
44
+ self.sample_rate = 16000
45
+ self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
46
+ self.mode = 'live' if opt.asr_wav == '' else 'file'
47
+
48
+ if 'esperanto' in self.opt.asr_model:
49
+ self.audio_dim = 44
50
+ elif 'deepspeech' in self.opt.asr_model:
51
+ self.audio_dim = 29
52
+ else:
53
+ self.audio_dim = 32
54
+
55
+ # prepare context cache
56
+ # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
57
+ self.context_size = opt.m
58
+ self.stride_left_size = opt.l
59
+ self.stride_right_size = opt.r
60
+ self.text = '[START]\n'
61
+ self.terminated = False
62
+ self.frames = []
63
+
64
+ # pad left frames
65
+ if self.stride_left_size > 0:
66
+ self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
67
+
68
+
69
+ self.exit_event = Event()
70
+ self.audio_instance = pyaudio.PyAudio()
71
+
72
+ # create input stream
73
+ if self.mode == 'file':
74
+ self.file_stream = self.create_file_stream()
75
+ else:
76
+ # start a background process to read frames
77
+ self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
78
+ self.queue = Queue()
79
+ self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
80
+
81
+ # play out the audio too...?
82
+ if self.play:
83
+ self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
84
+ self.output_queue = Queue()
85
+ self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
86
+
87
+ # current location of audio
88
+ self.idx = 0
89
+
90
+ # create wav2vec model
91
+ print(f'[INFO] loading ASR model {self.opt.asr_model}...')
92
+ self.processor = AutoProcessor.from_pretrained(opt.asr_model)
93
+ self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
94
+
95
+ # prepare to save logits
96
+ if self.opt.asr_save_feats:
97
+ self.all_feats = []
98
+
99
+ # the extracted features
100
+ # use a circular buffer to efficiently record endless features: [f--t---][-------][-------]
101
+ self.feat_buffer_size = 4
102
+ self.feat_buffer_idx = 0
103
+ self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
104
+
105
+ # TODO: hard coded 16 and 8 window size...
106
+ self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
107
+ self.tail = 8
108
+ # attention window...
109
+ self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
110
+
111
+ # warm up steps needed: mid + right + window_size + attention_size
112
+ self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
113
+
114
+ self.listening = False
115
+ self.playing = False
116
+
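The constructor's sliding-window bookkeeping works in 20 ms frames: each inference segment is stride_left + context + stride_right frames, and only the middle context frames yield features, so the expected latency is (context + stride_right) frames. With the defaults used in the __main__ block below (fps=50, -l 10 -m 50 -r 10):

```
# Segment length and latency for the sliding-window ASR, using the script defaults.
fps = 50                        # 20 ms per audio frame
l, m, r = 10, 50, 10            # stride_left, context, stride_right (in frames)

segment_s = (l + m + r) / fps   # 1.4 s of audio fed to wav2vec 2.0 per step
latency_s = (m + r) / fps       # 1.2 s expected feature latency
print(segment_s, latency_s)
```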
117
+ def listen(self):
118
+ # start
119
+ if self.mode == 'live' and not self.listening:
120
+ print(f'[INFO] starting read frame thread...')
121
+ self.process_read_frame.start()
122
+ self.listening = True
123
+
124
+ if self.play and not self.playing:
125
+ print(f'[INFO] starting play frame thread...')
126
+ self.process_play_frame.start()
127
+ self.playing = True
128
+
129
+ def stop(self):
130
+
131
+ self.exit_event.set()
132
+
133
+ if self.play:
134
+ self.output_stream.stop_stream()
135
+ self.output_stream.close()
136
+ if self.playing:
137
+ self.process_play_frame.join()
138
+ self.playing = False
139
+
140
+ if self.mode == 'live':
141
+ self.input_stream.stop_stream()
142
+ self.input_stream.close()
143
+ if self.listening:
144
+ self.process_read_frame.join()
145
+ self.listening = False
146
+
147
+
148
+ def __enter__(self):
149
+ return self
150
+
151
+ def __exit__(self, exc_type, exc_value, traceback):
152
+
153
+ self.stop()
154
+
155
+ if self.mode == 'live':
156
+ # live mode: also print the result text.
157
+ self.text += '\n[END]'
158
+ print(self.text)
159
+
160
+ def get_next_feat(self):
161
+ # return an [8, audio_dim, 16] attention window as the next input to the NeRF side.
162
+
163
+ while len(self.att_feats) < 8:
164
+ # [------f+++t-----]
165
+ if self.front < self.tail:
166
+ feat = self.feat_queue[self.front:self.tail]
167
+ # [++t-----------f+]
168
+ else:
169
+ feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
170
+
171
+ self.front = (self.front + 2) % self.feat_queue.shape[0]
172
+ self.tail = (self.tail + 2) % self.feat_queue.shape[0]
173
+
174
+ # print(self.front, self.tail, feat.shape)
175
+
176
+ self.att_feats.append(feat.permute(1, 0))
177
+
178
+ att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
179
+
180
+ # discard old
181
+ self.att_feats = self.att_feats[1:]
182
+
183
+ return att_feat
184
+
185
+ def run_step(self):
186
+
187
+ if self.terminated:
188
+ return
189
+
190
+ # get a frame of audio
191
+ frame = self.get_audio_frame()
192
+
193
+ # the last frame
194
+ if frame is None:
195
+ # terminate, but always run the network for the left frames
196
+ self.terminated = True
197
+ else:
198
+ self.frames.append(frame)
199
+ # put to output
200
+ if self.play:
201
+ self.output_queue.put(frame)
202
+ # context not enough, do not run network.
203
+ if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
204
+ return
205
+
206
+ inputs = np.concatenate(self.frames) # [N * chunk]
207
+
208
+ # discard the old part to save memory
209
+ if not self.terminated:
210
+ self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
211
+
212
+ logits, labels, text = self.frame_to_text(inputs)
213
+ feats = logits # better lip sync than labels
214
+
215
+ # save feats
216
+ if self.opt.asr_save_feats:
217
+ self.all_feats.append(feats)
218
+
219
+ # record the feats efficiently.. (no concat, constant memory)
220
+ if not self.terminated:
221
+ start = self.feat_buffer_idx * self.context_size
222
+ end = start + feats.shape[0]
223
+ self.feat_queue[start:end] = feats
224
+ self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
225
+
226
+ # very naive, just concat the text output.
227
+ if text != '':
228
+ self.text = self.text + ' ' + text
229
+
230
+ # will only run once at termination
231
+ if self.terminated:
232
+ self.text += '\n[END]'
233
+ print(self.text)
234
+ if self.opt.asr_save_feats:
235
+ print(f'[INFO] save all feats for training purpose... ')
236
+ feats = torch.cat(self.all_feats, dim=0) # [N, C]
237
+ # print('[INFO] before unfold', feats.shape)
238
+ window_size = 16
239
+ padding = window_size // 2
240
+ feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
241
+ feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
242
+ unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
243
+ unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
244
+ # print('[INFO] after unfold', unfold_feats.shape)
245
+ # save to a npy file
246
+ if 'esperanto' in self.opt.asr_model:
247
+ output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
248
+ else:
249
+ output_path = self.opt.asr_wav.replace('.wav', '.npy')
250
+ np.save(output_path, unfold_feats.cpu().numpy())
251
+ print(f"[INFO] saved logits to {output_path}")
252
+
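When asr_save_feats is set, the F.unfold call above slices the [M, C] logit sequence into overlapping 16-frame windows with stride 2 (audio features arrive at 50 fps, video at 25 fps, so this is roughly one window per video frame). A standalone sketch of the same reshaping on dummy features:

```
# Standalone sketch of the unfold-based windowing used when saving features.
import torch
import torch.nn.functional as F

M, C = 100, 32                               # 100 audio frames, 32 logits per frame
feats = torch.randn(M, C)

x = feats.permute(1, 0).view(1, C, M, 1)     # [1, C, M, 1]
win = F.unfold(x, kernel_size=(16, 1), padding=(8, 0), stride=(2, 1))
win = win.view(C, 16, -1).permute(2, 1, 0)   # -> [num_windows, 16, C]
print(win.shape)                             # torch.Size([51, 16, 32]), i.e. M / 2 + 1 windows
```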
253
+ def create_file_stream(self):
254
+
255
+ stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
256
+ stream = stream.astype(np.float32)
257
+
258
+ if stream.ndim > 1:
259
+ print(f'[WARN] audio has {stream.shape[1]} channels, only using the first.')
260
+ stream = stream[:, 0]
261
+
262
+ if sample_rate != self.sample_rate:
263
+ print(f'[WARN] audio sample rate is {sample_rate}, resampling to {self.sample_rate}.')
264
+ stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
265
+
266
+ print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
267
+
268
+ return stream
269
+
270
+
271
+ def create_pyaudio_stream(self):
272
+
273
+ import pyaudio
274
+
275
+ print(f'[INFO] creating live audio stream ...')
276
+
277
+ audio = pyaudio.PyAudio()
278
+
279
+ # get devices
280
+ info = audio.get_host_api_info_by_index(0)
281
+ n_devices = info.get('deviceCount')
282
+
283
+ for i in range(0, n_devices):
284
+ if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
285
+ name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
286
+ print(f'[INFO] choose audio device {name}, id {i}')
287
+ break
288
+
289
+ # get stream
290
+ stream = audio.open(input_device_index=i,
291
+ format=pyaudio.paInt16,
292
+ channels=1,
293
+ rate=self.sample_rate,
294
+ input=True,
295
+ frames_per_buffer=self.chunk)
296
+
297
+ return audio, stream
298
+
299
+
300
+ def get_audio_frame(self):
301
+
302
+ if self.mode == 'file':
303
+
304
+ if self.idx < self.file_stream.shape[0]:
305
+ frame = self.file_stream[self.idx: self.idx + self.chunk]
306
+ self.idx = self.idx + self.chunk
307
+ return frame
308
+ else:
309
+ return None
310
+
311
+ else:
312
+
313
+ frame = self.queue.get()
314
+ # print(f'[INFO] get frame {frame.shape}')
315
+
316
+ self.idx = self.idx + self.chunk
317
+
318
+ return frame
319
+
320
+
321
+ def frame_to_text(self, frame):
322
+ # frame: [N * 320], N = (context_size + 2 * stride_size)
323
+
324
+ inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
325
+
326
+ with torch.no_grad():
327
+ result = self.model(inputs.input_values.to(self.device))
328
+ logits = result.logits # [1, N - 1, 32]
329
+
330
+ # cut off stride
331
+ left = max(0, self.stride_left_size)
332
+ right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
333
+
334
+ # do not cut right if terminated.
335
+ if self.terminated:
336
+ right = logits.shape[1]
337
+
338
+ logits = logits[:, left:right]
339
+
340
+ # print(frame.shape, inputs.input_values.shape, logits.shape)
341
+
342
+ predicted_ids = torch.argmax(logits, dim=-1)
343
+ transcription = self.processor.batch_decode(predicted_ids)[0].lower()
344
+
345
+
346
+ # for esperanto
347
+ # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
348
+
349
+ # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
350
+ # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
351
+ # print(predicted_ids[0])
352
+ # print(transcription)
353
+
354
+ return logits[0], predicted_ids[0], transcription # [N,]
355
+
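frame_to_text relies on wav2vec 2.0 emitting roughly one logit per 20 ms chunk (the shape comment above notes about N - 1 output frames for N input chunks), so trimming stride_left_size frames at the front and stride_right_size - 1 at the back keeps approximately the context_size middle region. A small sketch of that trimming on dummy logits:

```
# Sketch of the stride trimming in frame_to_text, on dummy logits.
import torch

l, m, r = 10, 50, 10        # stride_left, context, stride_right (frames)
T = l + m + r - 1           # ~one logit per 20 ms chunk, one fewer than the input chunks
logits = torch.randn(1, T, 32)

left = max(0, l)
right = min(T, T - r + 1)   # the +1 makes the kept span exactly the context size
print(logits[:, left:right].shape)   # torch.Size([1, 50, 32])
```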
356
+
357
+ def run(self):
358
+
359
+ self.listen()
360
+
361
+ while not self.terminated:
362
+ self.run_step()
363
+
364
+ def clear_queue(self):
365
+ # clear the queue, to reduce potential latency...
366
+ print(f'[INFO] clear queue')
367
+ if self.mode == 'live':
368
+ self.queue.queue.clear()
369
+ if self.play:
370
+ self.output_queue.queue.clear()
371
+
372
+ def warm_up(self):
373
+
374
+ self.listen()
375
+
376
+ print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
377
+ t = time.time()
378
+ for _ in range(self.warm_up_steps):
379
+ self.run_step()
380
+ if torch.cuda.is_available():
381
+ torch.cuda.synchronize()
382
+ t = time.time() - t
383
+ print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
384
+
385
+ self.clear_queue()
386
+
387
+
388
+
389
+
390
+ if __name__ == '__main__':
391
+ import argparse
392
+
393
+ parser = argparse.ArgumentParser()
394
+ parser.add_argument('--wav', type=str, default='')
395
+ parser.add_argument('--play', action='store_true', help="play out the audio")
396
+
397
+ parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
398
+ # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
399
+
400
+ parser.add_argument('--save_feats', action='store_true')
401
+ # audio FPS
402
+ parser.add_argument('--fps', type=int, default=50)
403
+ # sliding window left-middle-right length.
404
+ parser.add_argument('-l', type=int, default=10)
405
+ parser.add_argument('-m', type=int, default=50)
406
+ parser.add_argument('-r', type=int, default=10)
407
+
408
+ opt = parser.parse_args()
409
+
410
+ # fix
411
+ opt.asr_wav = opt.wav
412
+ opt.asr_play = opt.play
413
+ opt.asr_model = opt.model
414
+ opt.asr_save_feats = opt.save_feats
415
+
416
+ if 'deepspeech' in opt.asr_model:
417
+ raise ValueError("DeepSpeech features should not use this code to extract...")
418
+
419
+ with ASR(opt) as asr:
420
+ asr.run()
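The __main__ block makes wav2vec.py usable as a standalone offline feature extractor; a hedged sketch of driving it programmatically (the wav path is illustrative, and pyaudio must be installed even in file mode because the constructor opens a PyAudio instance):

```
# Programmatic use of the ASR class above; 'demo.wav' is an illustrative path.
from argparse import Namespace
from wav2vec import ASR

opt = Namespace(asr_wav='demo.wav', asr_play=False,
                asr_model='cpierse/wav2vec2-large-xlsr-53-esperanto',
                asr_save_feats=True, fps=50, l=10, m=50, r=10)

with ASR(opt) as asr:
    asr.run()   # writes demo_eo.npy with [num_windows, 16, 44] features
```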
tensorflow-models/deepspeech-0_1_0-b90017e8.pb.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8866cb5098e4950e72912e49aa2ac192f279ecdd60f7721aa570b7a543451085
3
+ size 455423408
torch-hub/2DFAN4-cd938726ad.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd938726adb1f15f361263cce2db9cb820c42585fa8796ec72ce19107f369a46
3
+ size 96316515
torch-hub/s3fd-619a316812.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543
3
+ size 89843225