Upload 36 files
- .gitattributes +1 -0
- data_utils/deepspeech_features/README.md +20 -0
- data_utils/deepspeech_features/deepspeech_features.py +275 -0
- data_utils/deepspeech_features/deepspeech_store.py +172 -0
- data_utils/deepspeech_features/extract_ds_features.py +132 -0
- data_utils/deepspeech_features/extract_wav.py +87 -0
- data_utils/deepspeech_features/fea_win.py +11 -0
- data_utils/face_parsing/79999_iter.pth +3 -0
- data_utils/face_parsing/logger.py +23 -0
- data_utils/face_parsing/model.py +285 -0
- data_utils/face_parsing/resnet.py +109 -0
- data_utils/face_parsing/test.py +98 -0
- data_utils/face_tracking/.DS_Store +0 -0
- data_utils/face_tracking/3DMM/.DS_Store +0 -0
- data_utils/face_tracking/3DMM/01_MorphableModel.mat +3 -0
- data_utils/face_tracking/3DMM/01_MorphableModel.mat.zip +3 -0
- data_utils/face_tracking/3DMM/exp_info.npy +3 -0
- data_utils/face_tracking/3DMM/keys_info.npy +3 -0
- data_utils/face_tracking/3DMM/sub_mesh.obj +0 -0
- data_utils/face_tracking/3DMM/topology_info.npy +3 -0
- data_utils/face_tracking/__init__.py +0 -0
- data_utils/face_tracking/convert_BFM.py +39 -0
- data_utils/face_tracking/data_loader.py +16 -0
- data_utils/face_tracking/face_tracker.py +390 -0
- data_utils/face_tracking/facemodel.py +153 -0
- data_utils/face_tracking/geo_transform.py +69 -0
- data_utils/face_tracking/render_3dmm.py +202 -0
- data_utils/face_tracking/render_land.py +192 -0
- data_utils/face_tracking/util.py +109 -0
- data_utils/hubert.py +92 -0
- data_utils/process.py +405 -0
- data_utils/wav2mel.py +167 -0
- data_utils/wav2mel_hparams.py +80 -0
- data_utils/wav2vec.py +420 -0
- tensorflow-models/deepspeech-0_1_0-b90017e8.pb.zip +3 -0
- torch-hub/2DFAN4-cd938726ad.zip +3 -0
- torch-hub/s3fd-619a316812.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data_utils/face_tracking/3DMM/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
data_utils/deepspeech_features/README.md
ADDED
@@ -0,0 +1,20 @@
# Routines for DeepSpeech features processing
Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) feature processing, such as speech feature generation for the [VOCA](https://github.com/TimoBolkart/voca) model.

## Installation

```
pip3 install -r requirements.txt
```

## Usage

Generate wav files:
```
python3 extract_wav.py --in-video=<your_data_dir>
```

Generate files with DeepSpeech features:
```
python3 extract_ds_features.py --input=<your_data_dir>
```
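Taken together, a typical end-to-end run over a directory of videos looks like this (the directory name is a placeholder):

```
python3 extract_wav.py --in-video=<your_data_dir>
python3 extract_ds_features.py --input=<your_data_dir>
```

`extract_wav.py` writes a 16 kHz mono wav next to each video, and `extract_ds_features.py` then picks up every `.wav` in the directory and saves the windowed features as `<name>.npy`.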
data_utils/deepspeech_features/deepspeech_features.py
ADDED
@@ -0,0 +1,275 @@
"""
DeepSpeech features processing routines.
NB: Based on VOCA code. See the corresponding license restrictions.
"""

__all__ = ['conv_audios_to_deepspeech']

import numpy as np
import warnings
import resampy
from scipy.io import wavfile
from python_speech_features import mfcc
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


def conv_audios_to_deepspeech(audios,
                              out_files,
                              num_frames_info,
                              deepspeech_pb_path,
                              audio_window_size=1,
                              audio_window_stride=1):
    """
    Convert list of audio files into files with DeepSpeech features.

    Parameters
    ----------
    audios : list of str or list of None
        Paths to input audio files.
    out_files : list of str
        Paths to output files with DeepSpeech features.
    num_frames_info : list of int
        List of numbers of frames.
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.
    audio_window_size : int, default 1
        Audio window size.
    audio_window_stride : int, default 1
        Audio window stride.
    """
    # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
        deepspeech_pb_path)

    with tf.compat.v1.Session(graph=graph) as sess:
        for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
            print(audio_file_path)
            print(out_file_path)
            audio_sample_rate, audio = wavfile.read(audio_file_path)
            if audio.ndim != 1:
                warnings.warn(
                    "Audio has multiple channels, the first channel is used")
                audio = audio[:, 0]
            ds_features = pure_conv_audio_to_deepspeech(
                audio=audio,
                audio_sample_rate=audio_sample_rate,
                audio_window_size=audio_window_size,
                audio_window_stride=audio_window_stride,
                num_frames=num_frames,
                net_fn=lambda x: sess.run(
                    logits_ph,
                    feed_dict={
                        input_node_ph: x[np.newaxis, ...],
                        input_lengths_ph: [x.shape[0]]}))

            net_output = ds_features.reshape(-1, 29)
            win_size = 16
            zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
            net_output = np.concatenate(
                (zero_pad, net_output, zero_pad), axis=0)
            windows = []
            for window_index in range(0, net_output.shape[0] - win_size, 2):
                windows.append(
                    net_output[window_index:window_index + win_size])
            print(np.array(windows).shape)
            np.save(out_file_path, np.array(windows))


def prepare_deepspeech_net(deepspeech_pb_path):
    """
    Load and prepare DeepSpeech network.

    Parameters
    ----------
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.

    Returns
    -------
    graph : obj
        TensorFlow graph.
    logits_ph : obj
        TensorFlow placeholder for `logits`.
    input_node_ph : obj
        TensorFlow placeholder for `input_node`.
    input_lengths_ph : obj
        TensorFlow placeholder for `input_lengths`.
    """
    # Load graph and placeholders:
    with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    graph = tf.compat.v1.get_default_graph()
    tf.import_graph_def(graph_def, name="deepspeech")
    logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
    input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
    input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")

    return graph, logits_ph, input_node_ph, input_lengths_ph


def pure_conv_audio_to_deepspeech(audio,
                                  audio_sample_rate,
                                  audio_window_size,
                                  audio_window_stride,
                                  num_frames,
                                  net_fn):
    """
    Core routine for converting audio into DeepSpeech features.

    Parameters
    ----------
    audio : np.array
        Audio data.
    audio_sample_rate : int
        Audio sample rate.
    audio_window_size : int
        Audio window size.
    audio_window_stride : int
        Audio window stride.
    num_frames : int or None
        Numbers of frames.
    net_fn : func
        Function for DeepSpeech model call.

    Returns
    -------
    np.array
        DeepSpeech features.
    """
    target_sample_rate = 16000
    if audio_sample_rate != target_sample_rate:
        resampled_audio = resampy.resample(
            x=audio.astype(np.float32),
            sr_orig=audio_sample_rate,
            sr_new=target_sample_rate)
    else:
        resampled_audio = audio.astype(np.float32)
    input_vector = conv_audio_to_deepspeech_input_vector(
        audio=resampled_audio.astype(np.int16),
        sample_rate=target_sample_rate,
        num_cepstrum=26,
        num_context=9)

    network_output = net_fn(input_vector)
    # print(network_output.shape)

    deepspeech_fps = 50
    video_fps = 50  # Change this option if video fps is different
    audio_len_s = float(audio.shape[0]) / audio_sample_rate
    if num_frames is None:
        num_frames = int(round(audio_len_s * video_fps))
    else:
        video_fps = num_frames / audio_len_s
    network_output = interpolate_features(
        features=network_output[:, 0],
        input_rate=deepspeech_fps,
        output_rate=video_fps,
        output_len=num_frames)

    # Make windows:
    zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
    network_output = np.concatenate(
        (zero_pad, network_output, zero_pad), axis=0)
    windows = []
    for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
        windows.append(
            network_output[window_index:window_index + audio_window_size])

    return np.array(windows)


def conv_audio_to_deepspeech_input_vector(audio,
                                          sample_rate,
                                          num_cepstrum,
                                          num_context):
    """
    Convert audio raw data into DeepSpeech input vector.

    Parameters
    ----------
    audio : np.array
        Audio data.
    sample_rate : int
        Audio sample rate.
    num_cepstrum : int
        Number of cepstrum.
    num_context : int
        Number of context.

    Returns
    -------
    np.array
        DeepSpeech input vector.
    """
    # Get mfcc coefficients:
    features = mfcc(
        signal=audio,
        samplerate=sample_rate,
        numcep=num_cepstrum)

    # We only keep every second feature (BiRNN stride = 2):
    features = features[::2]

    # One stride per time step in the input:
    num_strides = len(features)

    # Add empty initial and final contexts:
    empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
    features = np.concatenate((empty_context, features, empty_context))

    # Create a view into the array with overlapping strides of size
    # numcontext (past) + 1 (present) + numcontext (future):
    window_size = 2 * num_context + 1
    train_inputs = np.lib.stride_tricks.as_strided(
        features,
        shape=(num_strides, window_size, num_cepstrum),
        strides=(features.strides[0],
                 features.strides[0], features.strides[1]),
        writeable=False)

    # Flatten the second and third dimensions:
    train_inputs = np.reshape(train_inputs, [num_strides, -1])

    train_inputs = np.copy(train_inputs)
    train_inputs = (train_inputs - np.mean(train_inputs)) / \
        np.std(train_inputs)

    return train_inputs


def interpolate_features(features,
                         input_rate,
                         output_rate,
                         output_len):
    """
    Interpolate DeepSpeech features.

    Parameters
    ----------
    features : np.array
        DeepSpeech features.
    input_rate : int
        Input rate (FPS).
    output_rate : int
        Output rate (FPS).
    output_len : int
        Output data length.

    Returns
    -------
    np.array
        Interpolated data.
    """
    input_len = features.shape[0]
    num_features = features.shape[1]
    input_timestamps = np.arange(input_len) / float(input_rate)
    output_timestamps = np.arange(output_len) / float(output_rate)
    output_features = np.zeros((output_len, num_features))
    for feature_idx in range(num_features):
        output_features[:, feature_idx] = np.interp(
            x=output_timestamps,
            xp=input_timestamps,
            fp=features[:, feature_idx])
    return output_features
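Not part of the commit: a minimal sketch of calling the exported routine directly on one wav file; the paths are placeholders.

```python
from deepspeech_features import conv_audios_to_deepspeech

# One 16 kHz mono wav in, one .npy of windowed 29-dim logits out.
conv_audios_to_deepspeech(
    audios=["sample.wav"],
    out_files=["sample.npy"],
    num_frames_info=[None],  # None: derive the frame count from 50 fps
    deepspeech_pb_path="deepspeech-0_1_0-b90017e8.pb")
```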
data_utils/deepspeech_features/deepspeech_store.py
ADDED
@@ -0,0 +1,172 @@
"""
Routines for loading DeepSpeech model.
"""

__all__ = ['get_deepspeech_model_file']

import os
import zipfile
import logging
import hashlib


deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'


def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return the location of the pretrained model on the local file system. This function will download it from the
    online model zoo when the model cannot be found or has a hash mismatch. The root directory will be created if it
    doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if _check_sha1(file_path, sha1_hash):
            return file_path
        else:
            logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        logging.info("Model file not found. Downloading to {}.".format(file_path))

    if not os.path.exists(local_model_store_dir_path):
        os.makedirs(local_model_store_dir_path)

    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    if _check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError("Downloaded file has different hash. Please try again.")


def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in the url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non 200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    import warnings
    try:
        import requests
    except ImportError:
        class requests_failed_to_import(object):
            pass
        requests = requests_failed_to_import

    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print("Downloading {} from {}...".format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url {}".format(url))
                with open(fname, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, "s" if retries > 1 else ""))

    return fname


def _check_sha1(filename, sha1_hash):
    """
    Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash
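A sketch of the intended use: resolve (downloading on first call) the frozen DeepSpeech 0.1.0 graph, then hand the returned path to the feature extractor.

```python
from deepspeech_store import get_deepspeech_model_file

pb_path = get_deepspeech_model_file()  # defaults to ~/.tensorflow/models
print(pb_path)
```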
data_utils/deepspeech_features/extract_ds_features.py
ADDED
@@ -0,0 +1,132 @@
"""
Script for extracting DeepSpeech features from audio file.
"""

import os
import argparse
import numpy as np
import pandas as pd
from deepspeech_store import get_deepspeech_model_file
from deepspeech_features import conv_audios_to_deepspeech


def parse_args():
    """
    Create python script parameters.

    Returns
    -------
    ArgumentParser
        Resulted args.
    """
    parser = argparse.ArgumentParser(
        description="Extract DeepSpeech features from audio file",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="path to input audio file or directory")
    parser.add_argument(
        "--output",
        type=str,
        help="path to output file with DeepSpeech features")
    parser.add_argument(
        "--deepspeech",
        type=str,
        help="path to DeepSpeech 0.1.0 frozen model")
    parser.add_argument(
        "--metainfo",
        type=str,
        help="path to file with meta-information")

    args = parser.parse_args()
    return args


def extract_features(in_audios,
                     out_files,
                     deepspeech_pb_path,
                     metainfo_file_path=None):
    """
    Extract DeepSpeech features from audio files.

    Parameters
    ----------
    in_audios : list of str
        Paths to input audio files.
    out_files : list of str
        Paths to output files with DeepSpeech features.
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.
    metainfo_file_path : str, default None
        Path to file with meta-information.
    """
    # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    if metainfo_file_path is None:
        num_frames_info = [None] * len(in_audios)
    else:
        train_df = pd.read_csv(
            metainfo_file_path,
            sep="\t",
            index_col=False,
            dtype={"Id": int, "File": str, "Count": int})
        num_frames_info = train_df["Count"].values
        assert (len(num_frames_info) == len(in_audios))

    for i, in_audio in enumerate(in_audios):
        if not out_files[i]:
            file_stem, _ = os.path.splitext(in_audio)
            out_files[i] = file_stem + ".npy"
            # print(out_files[i])
    conv_audios_to_deepspeech(
        audios=in_audios,
        out_files=out_files,
        num_frames_info=num_frames_info,
        deepspeech_pb_path=deepspeech_pb_path)


def main():
    """
    Main body of script.
    """
    args = parse_args()
    in_audio = os.path.expanduser(args.input)
    if not os.path.exists(in_audio):
        raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
    deepspeech_pb_path = args.deepspeech
    # Added: force the default model location (auto-downloaded if missing).
    deepspeech_pb_path = True
    args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
    # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    if deepspeech_pb_path is None:
        deepspeech_pb_path = ""
    if deepspeech_pb_path:
        deepspeech_pb_path = os.path.expanduser(args.deepspeech)
        if not os.path.exists(deepspeech_pb_path):
            deepspeech_pb_path = get_deepspeech_model_file()
    if os.path.isfile(in_audio):
        extract_features(
            in_audios=[in_audio],
            out_files=[args.output],
            deepspeech_pb_path=deepspeech_pb_path,
            metainfo_file_path=args.metainfo)
    else:
        audio_file_paths = []
        for file_name in os.listdir(in_audio):
            if not os.path.isfile(os.path.join(in_audio, file_name)):
                continue
            _, file_ext = os.path.splitext(file_name)
            if file_ext.lower() == ".wav":
                audio_file_path = os.path.join(in_audio, file_name)
                audio_file_paths.append(audio_file_path)
        audio_file_paths = sorted(audio_file_paths)
        out_file_paths = [""] * len(audio_file_paths)
        extract_features(
            in_audios=audio_file_paths,
            out_files=out_file_paths,
            deepspeech_pb_path=deepspeech_pb_path,
            metainfo_file_path=args.metainfo)


if __name__ == "__main__":
    main()
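Note that `main()` overrides `args.deepspeech` with the default `~/.tensorflow/models` path, so the `--deepspeech` flag is effectively ignored and the model is fetched automatically when absent. A typical single-file run (paths are placeholders):

```
python3 extract_ds_features.py --input=data/sample.wav --output=data/sample.npy
```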
data_utils/deepspeech_features/extract_wav.py
ADDED
@@ -0,0 +1,87 @@
"""
Script for extracting audio (16-bit, mono, 16000 Hz) from video file.
"""

import os
import argparse
import subprocess


def parse_args():
    """
    Create python script parameters.

    Returns
    -------
    ArgumentParser
        Resulted args.
    """
    parser = argparse.ArgumentParser(
        description="Extract audio from video file",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--in-video",
        type=str,
        required=True,
        help="path to input video file or directory")
    parser.add_argument(
        "--out-audio",
        type=str,
        help="path to output audio file")

    args = parser.parse_args()
    return args


def extract_audio(in_video,
                  out_audio):
    """
    Extract audio from a video file.

    Parameters
    ----------
    in_video : str
        Path to input video file.
    out_audio : str
        Path to output audio file.
    """
    if not out_audio:
        file_stem, _ = os.path.splitext(in_video)
        out_audio = file_stem + ".wav"
    # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
    # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
    # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
    command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
    subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)


def main():
    """
    Main body of script.
    """
    args = parse_args()
    in_video = os.path.expanduser(args.in_video)
    if not os.path.exists(in_video):
        raise Exception("Input file/directory doesn't exist: {}".format(in_video))
    if os.path.isfile(in_video):
        extract_audio(
            in_video=in_video,
            out_audio=args.out_audio)
    else:
        video_file_paths = []
        for file_name in os.listdir(in_video):
            if not os.path.isfile(os.path.join(in_video, file_name)):
                continue
            _, file_ext = os.path.splitext(file_name)
            if file_ext.lower() in (".mp4", ".mkv", ".avi"):
                video_file_path = os.path.join(in_video, file_name)
                video_file_paths.append(video_file_path)
        video_file_paths = sorted(video_file_paths)
        for video_file_path in video_file_paths:
            extract_audio(
                in_video=video_file_path,
                out_audio="")


if __name__ == "__main__":
    main()
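The ffmpeg invocation is built as a single shell string; an equivalent argument-list form (a sketch, avoiding `shell=True` and its quoting pitfalls) would be:

```python
import subprocess

def extract_audio_argv(in_video, out_audio):
    # Same codec/channel/rate settings as the command string above.
    subprocess.call([
        "ffmpeg", "-i", in_video,
        "-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000",
        out_audio])
```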
data_utils/deepspeech_features/fea_win.py
ADDED
@@ -0,0 +1,11 @@
import numpy as np

net_output = np.load('french.ds.npy').reshape(-1, 29)
win_size = 16
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
windows = []
for window_index in range(0, net_output.shape[0] - win_size, 2):
    windows.append(net_output[window_index:window_index + win_size])
print(np.array(windows).shape)
np.save('aud_french.npy', np.array(windows))
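The stride-2 window over 50 fps DeepSpeech frames yields 25 windows per second, each covering 16 frames (0.32 s). A quick shape check with a hypothetical 10 s input:

```python
import numpy as np

net_output = np.zeros((500, 29))  # hypothetical: 10 s of 50 fps features
win_size = 16
padded = np.concatenate(
    (np.zeros((win_size // 2, 29)), net_output, np.zeros((win_size // 2, 29))),
    axis=0)
windows = [padded[i:i + win_size]
           for i in range(0, padded.shape[0] - win_size, 2)]
print(np.array(windows).shape)  # (250, 16, 29): 25 windows per second
```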
data_utils/face_parsing/79999_iter.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
size 53289463
data_utils/face_parsing/logger.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import os.path as osp
import time
import sys
import logging

import torch.distributed as dist


def setup_logger(logpth):
    logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
    logfile = osp.join(logpth, logfile)
    FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
    log_level = logging.INFO
    if dist.is_initialized() and not dist.get_rank() == 0:
        log_level = logging.ERROR
    logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
    logging.root.addHandler(logging.StreamHandler())
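Typical use, a sketch: call once at startup so subsequent `logging` calls go both to a timestamped file and to the console:

```python
from logger import setup_logger
import logging

setup_logger("./")  # writes BiSeNet-<timestamp>.log in the given directory
logging.info("logger ready")
```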
data_utils/face_parsing/model.py
ADDED
@@ -0,0 +1,285 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                              out_chan,
                              kernel_size=ks,
                              stride=stride,
                              padding=padding,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        H0, W0 = x.size()[2:]
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan,
                               out_chan // 4,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4,
                               out_chan,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNet(nn.Module):
    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)

        # return feat_out, feat_out16, feat_out32
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    net = BiSeNet(19)
    net.cuda()
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).cuda()
    out = net(in_ten)  # forward returns only the fused output head
    print(out.shape)

    net.get_params()
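A sketch of inference with the checkpoint shipped in this commit (the same load as in `test.py` below; note that constructing `BiSeNet` downloads the ResNet-18 backbone weights via `model_zoo`):

```python
import torch
from model import BiSeNet

net = BiSeNet(n_classes=19)
net.load_state_dict(torch.load("data_utils/face_parsing/79999_iter.pth",
                               map_location="cpu"))
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 512, 512))  # (1, 19, 512, 512)
```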
data_utils/face_parsing/resnet.py
ADDED
@@ -0,0 +1,109 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

# from modules.bn import InPlaceABNSync as BatchNorm2d

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum - 1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k:
                continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


if __name__ == "__main__":
    net = Resnet18()
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    print(out[0].size())
    print(out[1].size())
    print(out[2].size())
    net.get_params()
data_utils/face_parsing/test.py
ADDED
@@ -0,0 +1,98 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import numpy as np
from model import BiSeNet

import torch

import os
import os.path as osp

from PIL import Image
import torchvision.transforms as transforms
import cv2
from pathlib import Path
import configargparse
import tqdm

# import ttach as tta


def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    im = np.array(im)
    vis_im = im.copy().astype(np.uint8)
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(
        vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
    vis_parsing_anno_color = np.zeros(
        (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255])  # + 255

    num_of_class = np.max(vis_parsing_anno)
    # print(num_of_class)
    # Flat color map per label group: 1-13 -> red, 14-15 -> green, 16 -> blue, 17+ -> red.
    for pi in range(1, 14):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])

    for pi in range(14, 16):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
    for pi in range(16, 17):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
    for pi in range(17, num_of_class + 1):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    index = np.where(vis_parsing_anno == num_of_class - 1)
    vis_im = cv2.resize(vis_parsing_anno_color, img_size,
                        interpolation=cv2.INTER_NEAREST)
    if save_im:
        cv2.imwrite(save_path, vis_im)


def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    Path(respth).mkdir(parents=True, exist_ok=True)

    print(f'[INFO] loading model...')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.load_state_dict(torch.load(cp))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    image_paths = os.listdir(dspth)

    with torch.no_grad():
        for image_path in tqdm.tqdm(image_paths):
            if image_path.endswith('.jpg') or image_path.endswith('.png'):
                img = Image.open(osp.join(dspth, image_path))
                ori_size = img.size
                image = img.resize((512, 512), Image.BILINEAR)
                image = image.convert("RGB")
                img = to_tensor(image)

                # test-time augmentation.
                inputs = torch.unsqueeze(img, 0)  # [1, 3, 512, 512]
                outputs = net(inputs.cuda())
                parsing = outputs.mean(0).cpu().numpy().argmax(0)

                image_path = int(image_path[:-4])
                image_path = str(image_path) + '.png'

                vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)


if __name__ == "__main__":
    parser = configargparse.ArgumentParser()
    parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
    parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
    parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
    args = parser.parse_args()
    evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
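A typical invocation using the bundled checkpoint (directory names are placeholders):

```
python3 test.py --imgpath=./ori_imgs/ --respath=./parsing/ --modelpath=data_utils/face_parsing/79999_iter.pth
```

Note that `evaluate` assumes numeric frame filenames (e.g. `0.jpg`), since it casts the file stem to `int` when naming outputs.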
data_utils/face_tracking/.DS_Store
ADDED
Binary file (6.15 kB).
data_utils/face_tracking/3DMM/.DS_Store
ADDED
Binary file (6.15 kB).
data_utils/face_tracking/3DMM/01_MorphableModel.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
size 240875364
data_utils/face_tracking/3DMM/01_MorphableModel.mat.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5338eeb60e7b6702cf4ed728090ffe892b8dde0e42bba647184b0359f6e991aa
size 240178121
data_utils/face_tracking/3DMM/exp_info.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3196029ed038eb9a461df6c782125fc4d1ec1545f2e5f361891471136b6cbb6
size 33264853
data_utils/face_tracking/3DMM/keys_info.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:028d3c383bae129c4bdcac880e22c551a5d2436ec7db7a26f5c57148d12469e6
size 7375
data_utils/face_tracking/3DMM/sub_mesh.obj
ADDED
The diff for this file is too large to render.
data_utils/face_tracking/3DMM/topology_info.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2edff6a6ad574d2dddf0d0815e0beabfea9369d7c2d6e53e0ba81f809b81e963
size 4145201
data_utils/face_tracking/__init__.py
ADDED
File without changes
data_utils/face_tracking/convert_BFM.py
ADDED
@@ -0,0 +1,39 @@
import numpy as np
from scipy.io import loadmat

original_BFM = loadmat("3DMM/01_MorphableModel.mat")
sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"]

shapePC = original_BFM["shapePC"]
shapeEV = original_BFM["shapeEV"]
shapeMU = original_BFM["shapeMU"]
texPC = original_BFM["texPC"]
texEV = original_BFM["texEV"]
texMU = original_BFM["texMU"]

b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_shape = shapeMU.reshape(-1, 3)

b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_tex = texMU.reshape(-1, 3)

b_shape = b_shape[:, sub_inds, :].reshape(199, -1)
mu_shape = mu_shape[sub_inds, :].reshape(-1)
b_tex = b_tex[:, sub_inds, :].reshape(199, -1)
mu_tex = mu_tex[sub_inds, :].reshape(-1)

exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item()
np.save(
    "3DMM/3DMM_info.npy",
    {
        "mu_shape": mu_shape,
        "b_shape": b_shape,
        "sig_shape": shapeEV.reshape(-1),
        "mu_exp": exp_info["mu_exp"],
        "b_exp": exp_info["base_exp"],
        "sig_exp": exp_info["sig_exp"],
        "mu_tex": mu_tex,
        "b_tex": b_tex,
        "sig_tex": texEV.reshape(-1),
    },
)
data_utils/face_tracking/data_loader.py
ADDED
@@ -0,0 +1,16 @@
import os
import torch
import numpy as np


def load_dir(path, start, end):
    lmss = []
    imgs_paths = []
    for i in range(start, end):
        if os.path.isfile(os.path.join(path, str(i) + ".lms")):
            lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32)
            lmss.append(lms)
            imgs_paths.append(os.path.join(path, str(i) + ".jpg"))
    lmss = np.stack(lmss)
    lmss = torch.as_tensor(lmss).cuda()
    return lmss, imgs_paths
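A sketch of how the tracker below consumes this loader (directory layout assumed: `<path>/<i>.lms` landmark files alongside `<i>.jpg` frames):

```python
from data_loader import load_dir

# Landmarks are stacked into one CUDA tensor; indices without a .lms file
# are silently skipped.
lms, img_paths = load_dir("obama/ori_imgs", 0, 100)
print(lms.shape, len(img_paths))
```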
data_utils/face_tracking/face_tracker.py
ADDED
@@ -0,0 +1,390 @@
import os
import sys
import cv2
import argparse
from pathlib import Path
import torch
import numpy as np
from data_loader import load_dir
from facemodel import Face_3DMM
from util import *
from render_3dmm import Render_3DMM


# torch.autograd.set_detect_anomaly(True)

dir_path = os.path.dirname(os.path.realpath(__file__))


def set_requires_grad(tensor_list):
    for tensor in tensor_list:
        tensor.requires_grad = True


parser = argparse.ArgumentParser()
parser.add_argument(
    "--path", type=str, default="obama/ori_imgs", help="idname of target person"
)
parser.add_argument("--img_h", type=int, default=512, help="image height")
parser.add_argument("--img_w", type=int, default=512, help="image width")
parser.add_argument("--frame_num", type=int, default=11000, help="image number")
args = parser.parse_args()

start_id = 0
end_id = args.frame_num

lms, img_paths = load_dir(args.path, start_id, end_id)
num_frames = lms.shape[0]
h, w = args.img_h, args.img_w
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()
id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
model_3dmm = Face_3DMM(
    os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num
)

# only use one image per 40 to fit the focal length
sel_ids = np.arange(0, num_frames, 40)
sel_num = sel_ids.shape[0]
arg_focal = 1600
arg_landis = 1e5

print(f'[INFO] fitting focal length...')

# fit the focal length
for focal in range(600, 1500, 100):
    id_para = lms.new_zeros((1, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans.data[:, 2] -= 7
    focal_length = lms.new_zeros(1, requires_grad=False)
focal_length = lms.new_zeros(1, requires_grad=False)
|
61 |
+
focal_length.data += focal
|
62 |
+
set_requires_grad([id_para, exp_para, euler_angle, trans])
|
63 |
+
|
64 |
+
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
|
65 |
+
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)
|
66 |
+
|
67 |
+
for iter in range(2000):
|
68 |
+
id_para_batch = id_para.expand(sel_num, -1)
|
69 |
+
geometry = model_3dmm.get_3dlandmarks(
|
70 |
+
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
|
71 |
+
)
|
72 |
+
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
|
73 |
+
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
|
74 |
+
loss = loss_lan
|
75 |
+
optimizer_frame.zero_grad()
|
76 |
+
loss.backward()
|
77 |
+
optimizer_frame.step()
|
78 |
+
# if iter % 100 == 0:
|
79 |
+
# print(focal, 'pose', iter, loss.item())
|
80 |
+
|
81 |
+
for iter in range(2500):
|
82 |
+
id_para_batch = id_para.expand(sel_num, -1)
|
83 |
+
geometry = model_3dmm.get_3dlandmarks(
|
84 |
+
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
|
85 |
+
)
|
86 |
+
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
|
87 |
+
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
|
88 |
+
loss_regid = torch.mean(id_para * id_para)
|
89 |
+
loss_regexp = torch.mean(exp_para * exp_para)
|
90 |
+
loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
|
91 |
+
optimizer_idexp.zero_grad()
|
92 |
+
optimizer_frame.zero_grad()
|
93 |
+
loss.backward()
|
94 |
+
optimizer_idexp.step()
|
95 |
+
optimizer_frame.step()
|
96 |
+
# if iter % 100 == 0:
|
97 |
+
# print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
|
98 |
+
|
99 |
+
if iter % 1500 == 0 and iter >= 1500:
|
100 |
+
for param_group in optimizer_idexp.param_groups:
|
101 |
+
param_group["lr"] *= 0.2
|
102 |
+
for param_group in optimizer_frame.param_groups:
|
103 |
+
param_group["lr"] *= 0.2
|
104 |
+
|
105 |
+
print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
|
106 |
+
|
107 |
+
if loss_lan.item() < arg_landis:
|
108 |
+
arg_landis = loss_lan.item()
|
109 |
+
arg_focal = focal
|
110 |
+
|
111 |
+
print("[INFO] find best focal:", arg_focal)
|
112 |
+
|
113 |
+
print(f'[INFO] coarse fitting...')
|
114 |
+
|
115 |
+
# for all frames, do a coarse fitting ???
|
116 |
+
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
|
117 |
+
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
118 |
+
tex_para = lms.new_zeros(
|
119 |
+
(1, tex_dim), requires_grad=True
|
120 |
+
) # not optimized in this block ???
|
121 |
+
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
|
122 |
+
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
|
123 |
+
light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
|
124 |
+
trans.data[:, 2] -= 7 # ???
|
125 |
+
focal_length = lms.new_zeros(1, requires_grad=True)
|
126 |
+
focal_length.data += arg_focal
|
127 |
+
|
128 |
+
set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para])
|
129 |
+
|
130 |
+
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
|
131 |
+
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)
|
132 |
+
|
133 |
+
for iter in range(1500):
|
134 |
+
id_para_batch = id_para.expand(num_frames, -1)
|
135 |
+
geometry = model_3dmm.get_3dlandmarks(
|
136 |
+
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
|
137 |
+
)
|
138 |
+
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
|
139 |
+
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
|
140 |
+
loss = loss_lan
|
141 |
+
optimizer_frame.zero_grad()
|
142 |
+
loss.backward()
|
143 |
+
optimizer_frame.step()
|
144 |
+
if iter == 1000:
|
145 |
+
for param_group in optimizer_frame.param_groups:
|
146 |
+
param_group["lr"] = 0.1
|
147 |
+
# if iter % 100 == 0:
|
148 |
+
# print('pose', iter, loss.item())
|
149 |
+
|
150 |
+
for param_group in optimizer_frame.param_groups:
|
151 |
+
param_group["lr"] = 0.1
|
152 |
+
|
153 |
+
for iter in range(2000):
|
154 |
+
id_para_batch = id_para.expand(num_frames, -1)
|
155 |
+
geometry = model_3dmm.get_3dlandmarks(
|
156 |
+
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
|
157 |
+
)
|
158 |
+
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
|
159 |
+
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
|
160 |
+
loss_regid = torch.mean(id_para * id_para)
|
161 |
+
loss_regexp = torch.mean(exp_para * exp_para)
|
162 |
+
loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
|
163 |
+
optimizer_idexp.zero_grad()
|
164 |
+
optimizer_frame.zero_grad()
|
165 |
+
loss.backward()
|
166 |
+
optimizer_idexp.step()
|
167 |
+
optimizer_frame.step()
|
168 |
+
# if iter % 100 == 0:
|
169 |
+
# print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
|
170 |
+
if iter % 1000 == 0 and iter >= 1000:
|
171 |
+
for param_group in optimizer_idexp.param_groups:
|
172 |
+
param_group["lr"] *= 0.2
|
173 |
+
for param_group in optimizer_frame.param_groups:
|
174 |
+
param_group["lr"] *= 0.2
|
175 |
+
|
176 |
+
print(loss_lan.item(), torch.mean(trans[:, 2]).item())
|
177 |
+
|
178 |
+
print(f'[INFO] fitting light...')
|
179 |
+
|
180 |
+
batch_size = 32
|
181 |
+
|
182 |
+
device_default = torch.device("cuda:0")
|
183 |
+
device_render = torch.device("cuda:0")
|
184 |
+
renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
|
185 |
+
|
186 |
+
sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
|
187 |
+
imgs = []
|
188 |
+
for sel_id in sel_ids:
|
189 |
+
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
|
190 |
+
imgs = np.stack(imgs)
|
191 |
+
sel_imgs = torch.as_tensor(imgs).cuda()
|
192 |
+
sel_lms = lms[sel_ids]
|
193 |
+
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
|
194 |
+
set_requires_grad([sel_light])
|
195 |
+
|
196 |
+
optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
|
197 |
+
optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)
|
198 |
+
|
199 |
+
for iter in range(71):
|
200 |
+
sel_exp_para, sel_euler, sel_trans = (
|
201 |
+
exp_para[sel_ids],
|
202 |
+
euler_angle[sel_ids],
|
203 |
+
trans[sel_ids],
|
204 |
+
)
|
205 |
+
sel_id_para = id_para.expand(batch_size, -1)
|
206 |
+
geometry = model_3dmm.get_3dlandmarks(
|
207 |
+
sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
|
208 |
+
)
|
209 |
+
proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
|
210 |
+
|
211 |
+
loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
|
212 |
+
loss_regid = torch.mean(id_para * id_para)
|
213 |
+
loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
|
214 |
+
|
215 |
+
sel_tex_para = tex_para.expand(batch_size, -1)
|
216 |
+
sel_texture = model_3dmm.forward_tex(sel_tex_para)
|
217 |
+
geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
|
218 |
+
rott_geo = forward_rott(geometry, sel_euler, sel_trans)
|
219 |
+
render_imgs = renderer(
|
220 |
+
rott_geo.to(device_render),
|
221 |
+
sel_texture.to(device_render),
|
222 |
+
sel_light.to(device_render),
|
223 |
+
)
|
224 |
+
render_imgs = render_imgs.to(device_default)
|
225 |
+
|
226 |
+
mask = (render_imgs[:, :, :, 3]).detach() > 0.0
|
227 |
+
render_proj = sel_imgs.clone()
|
228 |
+
render_proj[mask] = render_imgs[mask][..., :3].byte()
|
229 |
+
loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
|
230 |
+
|
231 |
+
if iter > 50:
|
232 |
+
loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
|
233 |
+
else:
|
234 |
+
loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0
|
235 |
+
|
236 |
+
optimizer_tl.zero_grad()
|
237 |
+
optimizer_id_frame.zero_grad()
|
238 |
+
loss.backward()
|
239 |
+
|
240 |
+
optimizer_tl.step()
|
241 |
+
optimizer_id_frame.step()
|
242 |
+
|
243 |
+
if iter % 50 == 0 and iter > 0:
|
244 |
+
for param_group in optimizer_id_frame.param_groups:
|
245 |
+
param_group["lr"] *= 0.2
|
246 |
+
for param_group in optimizer_tl.param_groups:
|
247 |
+
param_group["lr"] *= 0.2
|
248 |
+
# print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item())
|
249 |
+
|
250 |
+
|
251 |
+
light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
|
252 |
+
light_para.data = light_mean
|
253 |
+
|
254 |
+
exp_para = exp_para.detach()
|
255 |
+
euler_angle = euler_angle.detach()
|
256 |
+
trans = trans.detach()
|
257 |
+
light_para = light_para.detach()
|
258 |
+
|
259 |
+
print(f'[INFO] fine frame-wise fitting...')
|
260 |
+
|
261 |
+
for i in range(int((num_frames - 1) / batch_size + 1)):
|
262 |
+
|
263 |
+
if (i + 1) * batch_size > num_frames:
|
264 |
+
start_n = num_frames - batch_size
|
265 |
+
sel_ids = np.arange(num_frames - batch_size, num_frames)
|
266 |
+
else:
|
267 |
+
start_n = i * batch_size
|
268 |
+
sel_ids = np.arange(i * batch_size, i * batch_size + batch_size)
|
269 |
+
|
270 |
+
imgs = []
|
271 |
+
for sel_id in sel_ids:
|
272 |
+
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
|
273 |
+
imgs = np.stack(imgs)
|
274 |
+
sel_imgs = torch.as_tensor(imgs).cuda()
|
275 |
+
sel_lms = lms[sel_ids]
|
276 |
+
|
277 |
+
sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True)
|
278 |
+
sel_exp_para.data = exp_para[sel_ids].clone()
|
279 |
+
sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True)
|
280 |
+
sel_euler.data = euler_angle[sel_ids].clone()
|
281 |
+
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
|
282 |
+
sel_trans.data = trans[sel_ids].clone()
|
283 |
+
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
|
284 |
+
sel_light.data = light_para[sel_ids].clone()
|
285 |
+
|
286 |
+
set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light])
|
287 |
+
|
288 |
+
optimizer_cur_batch = torch.optim.Adam(
|
289 |
+
[sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
|
290 |
+
)
|
291 |
+
|
292 |
+
sel_id_para = id_para.expand(batch_size, -1).detach()
|
293 |
+
sel_tex_para = tex_para.expand(batch_size, -1).detach()
|
294 |
+
|
295 |
+
pre_num = 5
|
296 |
+
|
297 |
+
if i > 0:
|
298 |
+
pre_ids = np.arange(start_n - pre_num, start_n)
|
299 |
+
|
300 |
+
for iter in range(50):
|
301 |
+
|
302 |
+
geometry = model_3dmm.get_3dlandmarks(
|
303 |
+
sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
|
304 |
+
)
|
305 |
+
proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
|
306 |
+
loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
|
307 |
+
loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
|
308 |
+
|
309 |
+
sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
|
310 |
+
sel_texture = model_3dmm.forward_tex(sel_tex_para)
|
311 |
+
geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
|
312 |
+
rott_geo = forward_rott(geometry, sel_euler, sel_trans)
|
313 |
+
render_imgs = renderer(
|
314 |
+
rott_geo.to(device_render),
|
315 |
+
sel_texture.to(device_render),
|
316 |
+
sel_light.to(device_render),
|
317 |
+
)
|
318 |
+
render_imgs = render_imgs.to(device_default)
|
319 |
+
|
320 |
+
mask = (render_imgs[:, :, :, 3]).detach() > 0.0
|
321 |
+
|
322 |
+
loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
|
323 |
+
|
324 |
+
if i > 0:
|
325 |
+
geometry_lap = model_3dmm.forward_geo_sub(
|
326 |
+
id_para.expand(batch_size + pre_num, -1).detach(),
|
327 |
+
torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
|
328 |
+
model_3dmm.rigid_ids,
|
329 |
+
)
|
330 |
+
rott_geo_lap = forward_rott(
|
331 |
+
geometry_lap,
|
332 |
+
torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
|
333 |
+
torch.cat((trans[pre_ids].detach(), sel_trans)),
|
334 |
+
)
|
335 |
+
loss_lap = cal_lap_loss(
|
336 |
+
[rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
|
337 |
+
)
|
338 |
+
else:
|
339 |
+
geometry_lap = model_3dmm.forward_geo_sub(
|
340 |
+
id_para.expand(batch_size, -1).detach(),
|
341 |
+
sel_exp_para,
|
342 |
+
model_3dmm.rigid_ids,
|
343 |
+
)
|
344 |
+
rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
|
345 |
+
loss_lap = cal_lap_loss(
|
346 |
+
[rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
if iter > 30:
|
351 |
+
loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
|
352 |
+
else:
|
353 |
+
loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0
|
354 |
+
|
355 |
+
optimizer_cur_batch.zero_grad()
|
356 |
+
loss.backward()
|
357 |
+
optimizer_cur_batch.step()
|
358 |
+
|
359 |
+
# if iter % 10 == 0:
|
360 |
+
# print(
|
361 |
+
# i,
|
362 |
+
# iter,
|
363 |
+
# loss_col.item(),
|
364 |
+
# loss_lan.item(),
|
365 |
+
# loss_lap.item(),
|
366 |
+
# loss_regexp.item(),
|
367 |
+
# )
|
368 |
+
|
369 |
+
print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done")
|
370 |
+
|
371 |
+
render_proj = sel_imgs.clone()
|
372 |
+
render_proj[mask] = render_imgs[mask][..., :3].byte()
|
373 |
+
|
374 |
+
exp_para[sel_ids] = sel_exp_para.clone()
|
375 |
+
euler_angle[sel_ids] = sel_euler.clone()
|
376 |
+
trans[sel_ids] = sel_trans.clone()
|
377 |
+
light_para[sel_ids] = sel_light.clone()
|
378 |
+
|
379 |
+
torch.save(
|
380 |
+
{
|
381 |
+
"id": id_para.detach().cpu(),
|
382 |
+
"exp": exp_para.detach().cpu(),
|
383 |
+
"euler": euler_angle.detach().cpu(),
|
384 |
+
"trans": trans.detach().cpu(),
|
385 |
+
"focal": focal_length.detach().cpu(),
|
386 |
+
},
|
387 |
+
os.path.join(os.path.dirname(args.path), "track_params.pt"),
|
388 |
+
)
|
389 |
+
|
390 |
+
print("params saved")
|
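The tracker's only persistent output is `track_params.pt`, written next to the image directory. A minimal sketch of reading it back downstream (the path is illustrative):

```
# Load the fitted tracking parameters saved by face_tracker.py.
import torch

params = torch.load("obama/track_params.pt", map_location="cpu")
print(params["id"].shape)     # (1, 100): identity coefficients shared by all frames
print(params["exp"].shape)    # (num_frames, 79): per-frame expression coefficients
print(params["euler"].shape)  # (num_frames, 3): per-frame head rotation (Euler angles)
print(params["trans"].shape)  # (num_frames, 3): per-frame head translation
print(params["focal"])        # (1,): fitted focal length in pixels
```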
data_utils/face_tracking/facemodel.py
ADDED
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import numpy as np
import os
from util import *


class Face_3DMM(nn.Module):
    def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num):
        super(Face_3DMM, self).__init__()
        # id_dim = 100
        # exp_dim = 79
        # tex_dim = 100
        self.point_num = point_num
        DMM_info = np.load(
            os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True
        ).item()
        base_id = DMM_info["b_shape"][:id_dim, :]
        mu_id = DMM_info["mu_shape"]
        base_exp = DMM_info["b_exp"][:exp_dim, :]
        mu_exp = DMM_info["mu_exp"]
        mu = mu_id + mu_exp
        mu = mu.reshape(-1, 3)
        for i in range(3):
            mu[:, i] -= np.mean(mu[:, i])
        mu = mu.reshape(-1)
        self.base_id = torch.as_tensor(base_id).cuda() / 100000.0
        self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0
        self.mu = torch.as_tensor(mu).cuda() / 100000.0
        base_tex = DMM_info["b_tex"][:tex_dim, :]
        mu_tex = DMM_info["mu_tex"]
        self.base_tex = torch.as_tensor(base_tex).cuda()
        self.mu_tex = torch.as_tensor(mu_tex).cuda()
        sig_id = DMM_info["sig_shape"][:id_dim]
        sig_tex = DMM_info["sig_tex"][:tex_dim]
        sig_exp = DMM_info["sig_exp"][:exp_dim]
        self.sig_id = torch.as_tensor(sig_id).cuda()
        self.sig_tex = torch.as_tensor(sig_tex).cuda()
        self.sig_exp = torch.as_tensor(sig_exp).cuda()

        keys_info = np.load(
            os.path.join(modelpath, "keys_info.npy"), allow_pickle=True
        ).item()
        self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda()
        self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda()
        self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda()
        self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda()

    def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy):
        id_para = id_para * self.sig_id
        exp_para = exp_para * self.sig_exp
        batch_size = id_para.shape[0]
        num_per_contour = self.left_contours.shape[1]
        left_contours_flat = self.left_contours.reshape(-1)
        right_contours_flat = self.right_contours.reshape(-1)
        sel_index = torch.cat(
            (
                3 * left_contours_flat.unsqueeze(1),
                3 * left_contours_flat.unsqueeze(1) + 1,
                3 * left_contours_flat.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        left_geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        left_geometry = left_geometry.view(batch_size, -1, 3)
        proj_x = forward_transform(
            left_geometry, euler_angle, trans, focal_length, cxy
        )[:, :, 0]
        proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
        arg_min = proj_x.argmin(dim=2)
        left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3)
        left_3dlands = left_geometry[
            torch.arange(batch_size * 8), arg_min.view(-1), :
        ].view(batch_size, 8, 3)

        sel_index = torch.cat(
            (
                3 * right_contours_flat.unsqueeze(1),
                3 * right_contours_flat.unsqueeze(1) + 1,
                3 * right_contours_flat.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        right_geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        right_geometry = right_geometry.view(batch_size, -1, 3)
        proj_x = forward_transform(
            right_geometry, euler_angle, trans, focal_length, cxy
        )[:, :, 0]
        proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
        arg_max = proj_x.argmax(dim=2)
        right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3)
        right_3dlands = right_geometry[
            torch.arange(batch_size * 8), arg_max.view(-1), :
        ].view(batch_size, 8, 3)

        sel_index = torch.cat(
            (
                3 * self.keyinds.unsqueeze(1),
                3 * self.keyinds.unsqueeze(1) + 1,
                3 * self.keyinds.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        lands_3d = geometry.view(-1, self.keyinds.shape[0], 3)
        lands_3d[:, :8, :] = left_3dlands
        lands_3d[:, 9:17, :] = right_3dlands
        return lands_3d

    def forward_geo_sub(self, id_para, exp_para, sub_index):
        id_para = id_para * self.sig_id
        exp_para = exp_para * self.sig_exp
        sel_index = torch.cat(
            (
                3 * sub_index.unsqueeze(1),
                3 * sub_index.unsqueeze(1) + 1,
                3 * sub_index.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        return geometry.reshape(-1, sub_index.shape[0], 3)

    def forward_geo(self, id_para, exp_para):
        id_para = id_para * self.sig_id
        exp_para = exp_para * self.sig_exp
        geometry = (
            torch.mm(id_para, self.base_id)
            + torch.mm(exp_para, self.base_exp)
            + self.mu
        )
        return geometry.reshape(-1, self.point_num, 3)

    def forward_tex(self, tex_para):
        tex_para = tex_para * self.sig_tex
        texture = torch.mm(tex_para, self.base_tex) + self.mu_tex
        return texture.reshape(-1, self.point_num, 3)
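With all coefficients at zero, `forward_geo` returns the recentred mean face, which is a quick way to verify the 3DMM assets loaded correctly. A minimal sketch (CUDA required, since the class moves its bases to the GPU; the dimensions follow `face_tracker.py`):

```
# Zero coefficients reproduce the mean face (illustrative sanity check).
import torch
from facemodel import Face_3DMM

model = Face_3DMM("3DMM", 100, 79, 100, 34650)
id_para = torch.zeros(1, 100).cuda()
exp_para = torch.zeros(1, 79).cuda()
verts = model.forward_geo(id_para, exp_para)  # sig_* scaling is a no-op at zero
print(verts.shape)  # (1, 34650, 3)
```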
data_utils/face_tracking/geo_transform.py
ADDED
@@ -0,0 +1,69 @@
"""This module contains functions for geometry transform and camera projection"""
import torch
import torch.nn as nn
import numpy as np


def euler2rot(euler_angle):
    batch_size = euler_angle.shape[0]
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
    zero = torch.zeros(
        (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device
    )
    rot_x = torch.cat(
        (
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, theta.cos(), theta.sin()), 1),
            torch.cat((zero, -theta.sin(), theta.cos()), 1),
        ),
        2,
    )
    rot_y = torch.cat(
        (
            torch.cat((phi.cos(), zero, -phi.sin()), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((phi.sin(), zero, phi.cos()), 1),
        ),
        2,
    )
    rot_z = torch.cat(
        (
            torch.cat((psi.cos(), -psi.sin(), zero), 1),
            torch.cat((psi.sin(), psi.cos(), zero), 1),
            torch.cat((zero, zero, one), 1),
        ),
        2,
    )
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))


def rot_trans_geo(geometry, rot, trans):
    rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1)
    return rott_geo.permute(0, 2, 1)


def euler_trans_geo(geometry, euler, trans):
    rot = euler2rot(euler)
    return rot_trans_geo(geometry, rot, trans)


def proj_geo(rott_geo, camera_para):
    fx = camera_para[:, 0]
    fy = camera_para[:, 0]
    cx = camera_para[:, 1]
    cy = camera_para[:, 2]

    X = rott_geo[:, :, 0]
    Y = rott_geo[:, :, 1]
    Z = rott_geo[:, :, 2]

    fxX = fx[:, None] * X
    fyY = fy[:, None] * Y

    proj_x = -fxX / Z + cx[:, None]
    proj_y = fyY / Z + cy[:, None]

    return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
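A quick CPU check of the conventions above: zero Euler angles should give identity rotations, in which case `rot_trans_geo` reduces to a pure translation. A minimal sketch:

```
# Verify euler2rot / rot_trans_geo behave as expected at zero rotation.
import torch
from geo_transform import euler2rot, rot_trans_geo

euler = torch.zeros(2, 3)  # batch of 2, all angles zero
rot = euler2rot(euler)
print(torch.allclose(rot, torch.eye(3).expand(2, 3, 3)))  # True

pts = torch.randn(2, 5, 3)
trans = torch.tensor([[0.0, 0.0, -7.0], [1.0, 2.0, -7.0]])
moved = rot_trans_geo(pts, rot, trans)
print(torch.allclose(moved, pts + trans[:, None, :]))  # True
```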
data_utils/face_tracking/render_3dmm.py
ADDED
@@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import numpy as np
import os
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    blending,
)

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer.blending import (
    BlendParams,
    hard_rgb_blend,
    sigmoid_alpha_blend,
    softmax_rgb_blend,
)


class SoftSimpleShader(nn.Module):
    """
    Per pixel lighting - the lighting model is applied using the interpolated
    coordinates and normals for each pixel. The blending function returns the
    soft aggregated color using all the faces per pixel.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    """

    def __init__(
        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
    ):
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        # Manually move to device modules which are not subclasses of nn.Module
        self.cameras = self.cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:

        texels = meshes.sample_textures(fragments)
        blend_params = kwargs.get("blend_params", self.blend_params)

        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = "Cameras must be specified either at initialization \
                or in the forward pass of SoftPhongShader"
            raise ValueError(msg)
        znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
        images = softmax_rgb_blend(
            texels, fragments, blend_params, znear=znear, zfar=zfar
        )
        return images


class Render_3DMM(nn.Module):
    def __init__(
        self,
        focal=1015,
        img_h=500,
        img_w=500,
        batch_size=1,
        device=torch.device("cuda:0"),
    ):
        super(Render_3DMM, self).__init__()

        self.focal = focal
        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.renderer = self.get_render(batch_size)

        dir_path = os.path.dirname(os.path.realpath(__file__))
        topo_info = np.load(
            os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
        ).item()
        self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
        self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)

    def compute_normal(self, geometry):
        vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
        vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
        vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
        nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
        tri_normal = nn.functional.normalize(nnorm, dim=2)
        v_norm = tri_normal[:, self.vert_tris, :].sum(2)
        vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
        return vert_normal

    def get_render(self, batch_size=1):
        half_s = self.img_w * 0.5
        R, T = look_at_view_transform(10, 0, 0)
        R = R.repeat(batch_size, 1, 1)
        T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)

        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=R,
            T=T,
            znear=0.01,
            zfar=20,
            fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
        )
        lights = PointLights(
            device=self.device,
            location=[[0.0, 0.0, 1e5]],
            ambient_color=[[1, 1, 1]],
            specular_color=[[0.0, 0.0, 0.0]],
            diffuse_color=[[0.0, 0.0, 0.0]],
        )
        sigma = 1e-4
        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
            faces_per_pixel=2,
            perspective_correct=False,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
            shader=SoftSimpleShader(
                lights=lights, blend_params=blend_params, cameras=cameras
            ),
        )
        return renderer.to(self.device)

    @staticmethod
    def Illumination_layer(face_texture, norm, gamma):

        n_b, num_vertex, _ = face_texture.size()
        n_v_full = n_b * num_vertex
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8

        gamma = gamma.permute(0, 2, 1)

        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
        norm = norm.view(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
        arrH = []

        arrH.append(Y0)
        arrH.append(-a1 * c1 * ny)
        arrH.append(a1 * c1 * nz)
        arrH.append(-a1 * c1 * nx)
        arrH.append(a2 * c2 * nx * ny)
        arrH.append(-a2 * c2 * ny * nz)
        arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
        arrH.append(-a2 * c2 * nx * nz)
        arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

        H = torch.stack(arrH, 1)
        Y = H.view(n_b, num_vertex, 9)
        lighting = Y.bmm(gamma)

        face_color = face_texture * lighting
        return face_color

    def forward(self, rott_geometry, texture, diffuse_sh):
        face_normal = self.compute_normal(rott_geometry)
        face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
        face_color = TexturesVertex(face_color)
        mesh = Meshes(
            rott_geometry,
            self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
            face_color,
        )
        rendered_img = self.renderer(mesh)
        rendered_img = torch.clamp(rendered_img, 0, 255)

        return rendered_img
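A hedged usage sketch for `Render_3DMM` (needs pytorch3d and a CUDA device; the shapes are illustrative and random geometry will render garbage, but the tensor contract is the point). The batch size must match the one passed at construction, and the output is RGBA:

```
# Illustrative only: exercise the renderer's input/output shapes.
import torch
from render_3dmm import Render_3DMM

renderer = Render_3DMM(focal=1200, img_h=512, img_w=512, batch_size=2,
                       device=torch.device("cuda:0"))
rott_geo = torch.randn(2, 34650, 3).cuda()         # already posed vertices
texture = torch.full((2, 34650, 3), 128.0).cuda()  # flat gray per-vertex texture
light = torch.zeros(2, 27).cuda()                  # 9 SH coefficients x RGB
img = renderer(rott_geo, texture, light)
print(img.shape)  # (2, 512, 512, 4): RGB plus soft alpha, clamped to [0, 255]
```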
data_utils/face_tracking/render_land.py
ADDED
@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import render_util
import geo_transform
import numpy as np


def compute_tri_normal(geometry, tris):
    geometry = geometry.permute(0, 2, 1)
    tri_1 = tris[:, 0]
    tri_2 = tris[:, 1]
    tri_3 = tris[:, 2]

    vert_1 = torch.index_select(geometry, 2, tri_1)
    vert_2 = torch.index_select(geometry, 2, tri_2)
    vert_3 = torch.index_select(geometry, 2, tri_3)

    nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1)
    normal = nn.functional.normalize(nnorm).permute(0, 2, 1)
    return normal


class Compute_normal_base(torch.autograd.Function):
    @staticmethod
    def forward(ctx, normal):
        (normal_b,) = render_util.normal_base_forward(normal)
        ctx.save_for_backward(normal)
        return normal_b

    @staticmethod
    def backward(ctx, grad_normal_b):
        (normal,) = ctx.saved_tensors
        (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
        return grad_normal


class Normal_Base(torch.nn.Module):
    def __init__(self):
        super(Normal_Base, self).__init__()

    def forward(self, normal):
        return Compute_normal_base.apply(normal)


def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
    point_num = geometry.shape[1]
    rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
    proj_geo = geo_transform.proj_geo(rott_geo, cam)
    rot_tri_normal = compute_tri_normal(rott_geo, tris)
    rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)
    is_visible = -torch.bmm(
        rot_vert_normal.reshape(-1, 1, 3),
        nn.functional.normalize(rott_geo.reshape(-1, 3, 1)),
    ).reshape(-1, point_num)
    is_visible[is_visible < 0.01] = -1
    pixel_valid = torch.zeros(
        (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
        dtype=torch.float32,
        device=ori_img.device,
    )
    return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid


class Render_Face(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        batch_size, h, w, _ = ori_img.shape
        ori_img = ori_img.view(batch_size, -1, 3)
        ori_size = torch.cat(
            (
                torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
                * h,
                torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
                * w,
            ),
            dim=1,
        ).view(-1)
        tri_index, tri_coord, render, real = render_util.render_face_forward(
            proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
        )
        ctx.save_for_backward(
            ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
        )
        return render, real

    @staticmethod
    def backward(ctx, grad_render, grad_real):
        (
            ori_img,
            ori_size,
            proj_geo,
            texture,
            nbl,
            tri_inds,
            tri_index,
            tri_coord,
        ) = ctx.saved_tensors
        grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
            grad_render,
            grad_real,
            ori_img,
            ori_size,
            proj_geo,
            texture,
            nbl,
            tri_inds,
            tri_index,
            tri_coord,
        )
        return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None


class Render_RGB(nn.Module):
    def __init__(self):
        super(Render_RGB, self).__init__()

    def forward(
        self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        return Render_Face.apply(
            proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
        )


def cal_land(proj_geo, is_visible, lands_info, land_num):
    (land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
    proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[
        :, :2
    ].reshape(-1, land_num, 2)
    return proj_land


class Render_Land(nn.Module):
    def __init__(self):
        super(Render_Land, self).__init__()
        lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32)
        self.lands_info = torch.as_tensor(lands_info).cuda()
        tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64)
        self.tris = torch.as_tensor(tris).cuda() - 1
        vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64)
        self.vert_tris = torch.as_tensor(vert_tris).cuda()
        self.normal_baser = Normal_Base().cuda()
        self.renderer = Render_RGB().cuda()

    def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
        batch_size, h, w, _ = ori_img.shape
        ori_img = ori_img.view(batch_size, -1, 3)
        ori_size = torch.cat(
            (
                torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
                * h,
                torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
                * w,
            ),
            dim=1,
        ).view(-1)
        rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(
            tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
        )
        texture = torch.ones_like(geometry) * 200
        (render,) = render_util.render_mesh(
            proj_geo, ori_img, ori_size, texture, nbl, self.tris
        )
        return render.view(batch_size, h, w, 3).byte()

    def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
        rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
        render, real = self.renderer(
            proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
        )
        proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])
        col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
            ori_img.shape[0], -1
        )
        col_dis = torch.mean(col_minus * pixel_valid) / (
            torch.mean(pixel_valid) + 0.00001
        )
        land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
            ori_img.shape[0], -1
        )
        lan_dis = torch.mean(land_dists)
        return col_dis, lan_dis
data_utils/face_tracking/util.py
ADDED
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_tri_normal(geometry, tris):
    tri_1 = tris[:, 0]
    tri_2 = tris[:, 1]
    tri_3 = tris[:, 2]
    vert_1 = torch.index_select(geometry, 1, tri_1)
    vert_2 = torch.index_select(geometry, 1, tri_2)
    vert_3 = torch.index_select(geometry, 1, tri_3)
    nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
    normal = nn.functional.normalize(nnorm)
    return normal


def euler2rot(euler_angle):
    batch_size = euler_angle.shape[0]
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
    zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
    rot_x = torch.cat(
        (
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, theta.cos(), theta.sin()), 1),
            torch.cat((zero, -theta.sin(), theta.cos()), 1),
        ),
        2,
    )
    rot_y = torch.cat(
        (
            torch.cat((phi.cos(), zero, -phi.sin()), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((phi.sin(), zero, phi.cos()), 1),
        ),
        2,
    )
    rot_z = torch.cat(
        (
            torch.cat((psi.cos(), -psi.sin(), zero), 1),
            torch.cat((psi.sin(), psi.cos(), zero), 1),
            torch.cat((zero, zero, one), 1),
        ),
        2,
    )
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))


def rot_trans_pts(geometry, rot, trans):
    rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None]
    return rott_geo.permute(0, 2, 1)


def cal_lap_loss(tensor_list, weight_list):
    lap_kernel = (
        torch.Tensor((-0.5, 1.0, -0.5))
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
        .to(tensor_list[0].device)
    )
    loss_lap = 0
    for i in range(len(tensor_list)):
        in_tensor = tensor_list[i]
        in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1])
        out_tensor = F.conv1d(in_tensor, lap_kernel)
        loss_lap += torch.mean(out_tensor ** 2) * weight_list[i]
    return loss_lap


def proj_pts(rott_geo, focal_length, cxy):
    cx, cy = cxy[0], cxy[1]
    X = rott_geo[:, :, 0]
    Y = rott_geo[:, :, 1]
    Z = rott_geo[:, :, 2]
    fxX = focal_length * X
    fyY = focal_length * Y
    proj_x = -fxX / Z + cx
    proj_y = fyY / Z + cy
    return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)


def forward_rott(geometry, euler_angle, trans):
    rot = euler2rot(euler_angle)
    rott_geo = rot_trans_pts(geometry, rot, trans)
    return rott_geo


def forward_transform(geometry, euler_angle, trans, focal_length, cxy):
    rot = euler2rot(euler_angle)
    rott_geo = rot_trans_pts(geometry, rot, trans)
    proj_geo = proj_pts(rott_geo, focal_length, cxy)
    return proj_geo


def cal_lan_loss(proj_lan, gt_lan):
    return torch.mean((proj_lan - gt_lan) ** 2)


def cal_col_loss(pred_img, gt_img, img_mask):
    pred_img = pred_img.float()
    # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255
    loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
    loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2))
    loss = torch.mean(loss)
    return loss
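`cal_lap_loss` slides the `(-0.5, 1.0, -0.5)` kernel along the last axis, i.e. it penalizes the discrete second difference, so a trajectory that is linear in time incurs (numerically) zero loss. A small CPU-runnable check:

```
# The Laplacian kernel zeroes out on linear trajectories (illustrative check).
import torch
from util import cal_lap_loss

smooth = torch.linspace(0, 1, 50).unsqueeze(0)   # [1, 50], linear in time
noisy = smooth + 0.1 * torch.randn_like(smooth)
print(cal_lap_loss([smooth], [1.0]))  # ~0
print(cal_lap_loss([noisy], [1.0]))   # clearly larger
```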
data_utils/hubert.py
ADDED
@@ -0,0 +1,92 @@
from transformers import Wav2Vec2Processor, HubertModel
import soundfile as sf
import numpy as np
import torch

print("Loading the Wav2Vec2 Processor...")
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
print("Loading the HuBERT Model...")
hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


def get_hubert_from_16k_wav(wav_16k_name):
    speech_16k, _ = sf.read(wav_16k_name)
    hubert = get_hubert_from_16k_speech(speech_16k)
    return hubert


@torch.no_grad()
def get_hubert_from_16k_speech(speech, device="cuda:0"):
    global hubert_model
    hubert_model = hubert_model.to(device)
    if speech.ndim == 2:
        speech = speech[:, 0]  # [T, 2] ==> [T,]
    input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values  # [1, T]
    input_values_all = input_values_all.to(device)
    # For long audio sequences, due to the memory limitation, we cannot process them in one run.
    # HuBERT processes the wav with a CNN of strides [5,2,2,2,2,2], giving an overall stride of 320.
    # Besides, the kernels are [10,3,3,3,3,2,2], making 400 samples the fundamental unit for 1 time step.
    # So the CNN is equal to a single big Conv1D with kernel k=400 and stride s=320.
    # The output time-step count follows T = floor((t - k) / s).
    # To prevent overlap, we set each clip length to (k + s * (N - 1)), where N is the expected length T of this clip.
    # The start point of the next clip rolls back by (kernel - stride), so each clip advances by stride * N.
    kernel = 400
    stride = 320
    clip_length = stride * 1000
    num_iter = input_values_all.shape[1] // clip_length
    expected_T = (input_values_all.shape[1] - (kernel - stride)) // stride
    res_lst = []
    for i in range(num_iter):
        if i == 0:
            start_idx = 0
            end_idx = clip_length - stride + kernel
        else:
            start_idx = clip_length * i
            end_idx = start_idx + (clip_length - stride + kernel)
        input_values = input_values_all[:, start_idx:end_idx]
        hidden_states = hubert_model.forward(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
        res_lst.append(hidden_states[0])
    if num_iter > 0:
        input_values = input_values_all[:, clip_length * num_iter:]
    else:
        input_values = input_values_all
    # if input_values.shape[1] != 0:
    if input_values.shape[1] >= kernel:  # if the last batch is shorter than the kernel size, skip it
        hidden_states = hubert_model(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
        res_lst.append(hidden_states[0])
    ret = torch.cat(res_lst, dim=0).cpu()  # [T, 1024]
    # assert ret.shape[0] == expected_T
    assert abs(ret.shape[0] - expected_T) <= 1
    if ret.shape[0] < expected_T:
        ret = torch.nn.functional.pad(ret, (0, 0, 0, expected_T - ret.shape[0]))
    else:
        ret = ret[:expected_T]
    return ret


def make_even_first_dim(tensor):
    size = list(tensor.size())
    if size[0] % 2 == 1:
        size[0] -= 1
        return tensor[:size[0]]
    return tensor


from argparse import ArgumentParser
import librosa

parser = ArgumentParser()
parser.add_argument('--wav', type=str, help='')
args = parser.parse_args()

wav_name = args.wav

speech, sr = sf.read(wav_name)
speech_16k = librosa.resample(speech, orig_sr=sr, target_sr=16000)
print("SR: {} to {}".format(sr, 16000))
# print(speech.shape, speech_16k.shape)

hubert_hidden = get_hubert_from_16k_speech(speech_16k)
hubert_hidden = make_even_first_dim(hubert_hidden).reshape(-1, 2, 1024)
np.save(wav_name.replace('.wav', '_hu.npy'), hubert_hidden.detach().numpy())
print(hubert_hidden.detach().numpy().shape)
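The comments in `get_hubert_from_16k_speech` model the HuBERT feature extractor as one Conv1D with kernel 400 and stride 320 at 16 kHz, i.e. one step per 20 ms (~50 steps/s); the final `reshape(-1, 2, 1024)` then pairs steps to land on ~25 features per second, matching the 25 fps frame extraction in `process.py`. A worked arithmetic check of that formula:

```
# Check T = floor((t - k) / s) for 10 s of 16 kHz audio (illustration only).
samples = 16000 * 10
kernel, stride = 400, 320
expected_T = (samples - (kernel - stride)) // stride
print(expected_T)       # 499, i.e. ~50 HuBERT steps per second
print(expected_T // 2)  # 249, i.e. ~25 paired features per second (matches 25 fps)
```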
data_utils/process.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import tqdm
|
4 |
+
import json
|
5 |
+
import argparse
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
def extract_audio(path, out_path, sample_rate=16000):
|
10 |
+
|
11 |
+
print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
|
12 |
+
cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}'
|
13 |
+
os.system(cmd)
|
14 |
+
print(f'[INFO] ===== extracted audio =====')
|
15 |
+
|
16 |
+
|
17 |
+
def extract_audio_features(path, mode='wav2vec'):
|
18 |
+
|
19 |
+
print(f'[INFO] ===== extract audio labels for {path} =====')
|
20 |
+
if mode == 'wav2vec':
|
21 |
+
cmd = f'python nerf/asr.py --wav {path} --save_feats'
|
22 |
+
else: # deepspeech
|
23 |
+
cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}'
|
24 |
+
os.system(cmd)
|
25 |
+
print(f'[INFO] ===== extracted audio labels =====')
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def extract_images(path, out_path, fps=25):
|
30 |
+
|
31 |
+
print(f'[INFO] ===== extract images from {path} to {out_path} =====')
|
32 |
+
cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}'
|
33 |
+
os.system(cmd)
|
34 |
+
print(f'[INFO] ===== extracted images =====')
|
35 |
+
|
36 |
+
|
37 |
+
def extract_semantics(ori_imgs_dir, parsing_dir):
|
38 |
+
|
39 |
+
print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
|
40 |
+
cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}'
|
41 |
+
os.system(cmd)
|
42 |
+
print(f'[INFO] ===== extracted semantics =====')
|
43 |
+
|
44 |
+
|
45 |
+
def extract_landmarks(ori_imgs_dir):
|
46 |
+
|
47 |
+
print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
|
48 |
+
|
49 |
+
import face_alignment
|
50 |
+
try:
|
51 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
|
52 |
+
except:
|
53 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
|
54 |
+
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
|
55 |
+
for image_path in tqdm.tqdm(image_paths):
|
56 |
+
input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
57 |
+
input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
|
58 |
+
preds = fa.get_landmarks(input)
|
59 |
+
if len(preds) > 0:
|
60 |
+
lands = preds[0].reshape(-1, 2)[:,:2]
|
61 |
+
np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f')
|
62 |
+
del fa
|
63 |
+
print(f'[INFO] ===== extracted face landmarks =====')
|
64 |
+
|
65 |
+
|
66 |
+
def extract_background(base_dir, ori_imgs_dir):
|
67 |
+
|
68 |
+
print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')
|
69 |
+
|
70 |
+
from sklearn.neighbors import NearestNeighbors
|
71 |
+
|
72 |
+
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
|
73 |
+
# only use 1/20 image_paths
|
74 |
+
image_paths = image_paths[::20]
|
75 |
+
# read one image to get H/W
|
76 |
+
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
77 |
+
h, w = tmp_image.shape[:2]
|
78 |
+
|
79 |
+
# nearest neighbors
|
80 |
+
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
|
81 |
+
distss = []
|
82 |
+
for image_path in tqdm.tqdm(image_paths):
|
83 |
+
parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
|
84 |
+
bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
|
85 |
+
fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
|
86 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
87 |
+
dists, _ = nbrs.kneighbors(all_xys)
|
88 |
+
distss.append(dists)
|
89 |
+
|
90 |
+
distss = np.stack(distss)
|
91 |
+
max_dist = np.max(distss, 0)
|
92 |
+
max_id = np.argmax(distss, 0)
|
93 |
+
|
94 |
+
bc_pixs = max_dist > 5
|
95 |
+
bc_pixs_id = np.nonzero(bc_pixs)
|
96 |
+
bc_ids = max_id[bc_pixs]
|
97 |
+
|
98 |
+
imgs = []
|
99 |
+
num_pixs = distss.shape[1]
|
100 |
+
for image_path in image_paths:
|
101 |
+
img = cv2.imread(image_path)
|
102 |
+
imgs.append(img)
|
103 |
+
imgs = np.stack(imgs).reshape(-1, num_pixs, 3)
|
104 |
+
|
105 |
+
bc_img = np.zeros((h*w, 3), dtype=np.uint8)
|
106 |
+
bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
|
107 |
+
bc_img = bc_img.reshape(h, w, 3)
|
108 |
+
|
109 |
+
max_dist = max_dist.reshape(h, w)
|
110 |
+
bc_pixs = max_dist > 5
|
111 |
+
bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
|
112 |
+
fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
|
113 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
114 |
+
distances, indices = nbrs.kneighbors(bg_xys)
|
115 |
+
bg_fg_xys = fg_xys[indices[:, 0]]
|
116 |
+
bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
|
117 |
+
|
118 |
+
cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)
|
119 |
+
|
120 |
+
print(f'[INFO] ===== extracted background image =====')
|
121 |
+
|
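The key step above is the paired fancy indexing `imgs[bc_ids, bc_pixs_id, :]`: each background pixel takes its color from the single sampled frame in which it lies farthest from any foreground pixel. A minimal, self-contained sketch of that argmax-then-gather pattern (toy shapes and values, not the real data):

```python
import numpy as np

distss = np.array([[1., 5., 2.],    # [F, P]: distance to foreground, for F frames and P pixels
                   [4., 1., 3.]])
imgs = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # [F, P, 3]: per-frame pixel colors

max_id = np.argmax(distss, 0)       # most background-like frame per pixel: [1, 0, 1]
pix_id = np.arange(distss.shape[1])
print(imgs[max_id, pix_id, :])      # pixel 0 from frame 1, pixel 1 from frame 0, pixel 2 from frame 1
```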
122 |
+
|
123 |
+
def extract_torso_and_gt(base_dir, ori_imgs_dir):
|
124 |
+
|
125 |
+
print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')
|
126 |
+
|
127 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
128 |
+
|
129 |
+
# load bg
|
130 |
+
bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED)
|
131 |
+
|
132 |
+
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
|
133 |
+
|
134 |
+
for image_path in tqdm.tqdm(image_paths):
|
135 |
+
# read ori image
|
136 |
+
ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
137 |
+
|
138 |
+
# read semantics
|
139 |
+
seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
|
140 |
+
head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
|
141 |
+
neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
|
142 |
+
torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
|
143 |
+
bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)
|
144 |
+
|
145 |
+
# get gt image
|
146 |
+
gt_image = ori_image.copy()
|
147 |
+
gt_image[bg_part] = bg_image[bg_part]
|
148 |
+
cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
|
149 |
+
|
150 |
+
# get torso image
|
151 |
+
torso_image = gt_image.copy() # rgb
|
152 |
+
torso_image[head_part] = bg_image[head_part]
|
153 |
+
torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
|
154 |
+
|
155 |
+
# torso part "vertical" in-painting...
|
156 |
+
L = 8 + 1
|
157 |
+
torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
|
158 |
+
# lexsort: sort 2D coords first by y then by x,
|
159 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
160 |
+
inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
|
161 |
+
torso_coords = torso_coords[inds]
|
162 |
+
# choose the top pixel for each column
|
163 |
+
u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
|
164 |
+
top_torso_coords = torso_coords[uid] # [m, 2]
|
165 |
+
# only keep top-is-head pixels
|
166 |
+
top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0])
|
167 |
+
mask = head_part[tuple(top_torso_coords_up.T)]
|
168 |
+
if mask.any():
|
169 |
+
top_torso_coords = top_torso_coords[mask]
|
170 |
+
# get the color
|
171 |
+
top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
|
172 |
+
# construct inpaint coords (vertically up, or minus in x)
|
173 |
+
inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
|
174 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
175 |
+
inpaint_torso_coords += inpaint_offsets
|
176 |
+
inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
|
177 |
+
inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
|
178 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
179 |
+
inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
180 |
+
# set color
|
181 |
+
torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
|
182 |
+
|
183 |
+
inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
|
184 |
+
inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
|
185 |
+
else:
|
186 |
+
inpaint_torso_mask = None
|
187 |
+
|
188 |
+
|
189 |
+
# neck part "vertical" in-painting...
|
190 |
+
push_down = 4
|
191 |
+
L = 48 + push_down + 1
|
192 |
+
|
193 |
+
neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
|
194 |
+
|
195 |
+
neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
|
196 |
+
# lexsort: sort 2D coords first by y then by x,
|
197 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
198 |
+
inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
|
199 |
+
neck_coords = neck_coords[inds]
|
200 |
+
# choose the top pixel for each column
|
201 |
+
u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
|
202 |
+
top_neck_coords = neck_coords[uid] # [m, 2]
|
203 |
+
# only keep top-is-head pixels
|
204 |
+
top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
|
205 |
+
mask = head_part[tuple(top_neck_coords_up.T)]
|
206 |
+
|
207 |
+
top_neck_coords = top_neck_coords[mask]
|
208 |
+
# push these top coords down by up to 4 pixels to make the neck inpainting more natural...
|
209 |
+
offset_down = np.minimum(ucnt[mask] - 1, push_down)
|
210 |
+
top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
|
211 |
+
# get the color
|
212 |
+
top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
|
213 |
+
# construct inpaint coords (vertically up, or minus in x)
|
214 |
+
inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
|
215 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
216 |
+
inpaint_neck_coords += inpaint_offsets
|
217 |
+
inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
|
218 |
+
inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
|
219 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
220 |
+
inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
221 |
+
# set color
|
222 |
+
torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
|
223 |
+
|
224 |
+
# apply blurring to the inpaint area to avoid vertical-line artifacts...
|
225 |
+
inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
|
226 |
+
inpaint_mask[tuple(inpaint_neck_coords.T)] = True
|
227 |
+
|
228 |
+
blur_img = torso_image.copy()
|
229 |
+
blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
|
230 |
+
|
231 |
+
torso_image[inpaint_mask] = blur_img[inpaint_mask]
|
232 |
+
|
233 |
+
# set mask
|
234 |
+
mask = (neck_part | torso_part | inpaint_mask)
|
235 |
+
if inpaint_torso_mask is not None:
|
236 |
+
mask = mask | inpaint_torso_mask
|
237 |
+
torso_image[~mask] = 0
|
238 |
+
torso_alpha[~mask] = 0
|
239 |
+
|
240 |
+
cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1))
|
241 |
+
|
242 |
+
print(f'[INFO] ===== extracted torso and gt images =====')
|
243 |
+
|
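Both in-painting passes above use the same lexsort-then-unique idiom to find the topmost pixel of each column: after sorting coords by (column, row), the first occurrence of each column has the smallest row, i.e. the top of the torso/neck silhouette. A toy sketch:

```python
import numpy as np

coords = np.array([[5, 2], [3, 2], [7, 1], [4, 1]])        # [row, col] pairs
coords = coords[np.lexsort((coords[:, 0], coords[:, 1]))]  # sort by col, then row
_, uid = np.unique(coords[:, 1], return_index=True)        # first index per column
print(coords[uid])                                         # [[4 1] [3 2]] -- topmost pixel per column
```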
244 |
+
|
245 |
+
def face_tracking(ori_imgs_dir):
|
246 |
+
|
247 |
+
print(f'[INFO] ===== perform face tracking =====')
|
248 |
+
|
249 |
+
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
|
250 |
+
|
251 |
+
# read one image to get H/W
|
252 |
+
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
253 |
+
h, w = tmp_image.shape[:2]
|
254 |
+
|
255 |
+
cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}'
|
256 |
+
|
257 |
+
os.system(cmd)
|
258 |
+
|
259 |
+
print(f'[INFO] ===== finished face tracking =====')
|
260 |
+
|
261 |
+
|
262 |
+
def save_transforms(base_dir, ori_imgs_dir):
|
263 |
+
print(f'[INFO] ===== save transforms =====')
|
264 |
+
|
265 |
+
import torch
|
266 |
+
|
267 |
+
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
|
268 |
+
|
269 |
+
# read one image to get H/W
|
270 |
+
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
271 |
+
h, w = tmp_image.shape[:2]
|
272 |
+
|
273 |
+
params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
|
274 |
+
focal_len = params_dict['focal']
|
275 |
+
euler_angle = params_dict['euler']
|
276 |
+
trans = params_dict['trans'] / 10.0
|
277 |
+
valid_num = euler_angle.shape[0]
|
278 |
+
|
279 |
+
def euler2rot(euler_angle):
|
280 |
+
batch_size = euler_angle.shape[0]
|
281 |
+
theta = euler_angle[:, 0].reshape(-1, 1, 1)
|
282 |
+
phi = euler_angle[:, 1].reshape(-1, 1, 1)
|
283 |
+
psi = euler_angle[:, 2].reshape(-1, 1, 1)
|
284 |
+
one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
|
285 |
+
zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
|
286 |
+
rot_x = torch.cat((
|
287 |
+
torch.cat((one, zero, zero), 1),
|
288 |
+
torch.cat((zero, theta.cos(), theta.sin()), 1),
|
289 |
+
torch.cat((zero, -theta.sin(), theta.cos()), 1),
|
290 |
+
), 2)
|
291 |
+
rot_y = torch.cat((
|
292 |
+
torch.cat((phi.cos(), zero, -phi.sin()), 1),
|
293 |
+
torch.cat((zero, one, zero), 1),
|
294 |
+
torch.cat((phi.sin(), zero, phi.cos()), 1),
|
295 |
+
), 2)
|
296 |
+
rot_z = torch.cat((
|
297 |
+
torch.cat((psi.cos(), -psi.sin(), zero), 1),
|
298 |
+
torch.cat((psi.sin(), psi.cos(), zero), 1),
|
299 |
+
torch.cat((zero, zero, one), 1)
|
300 |
+
), 2)
|
301 |
+
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
|
302 |
+
|
303 |
+
|
304 |
+
# train_val_split = int(valid_num*0.5)
|
305 |
+
# train_val_split = valid_num - 25 * 20 # take the last 20s as valid set.
|
306 |
+
train_val_split = int(valid_num * 10 / 11)
|
307 |
+
|
308 |
+
train_ids = torch.arange(0, train_val_split)
|
309 |
+
val_ids = torch.arange(train_val_split, valid_num)
|
310 |
+
|
311 |
+
rot = euler2rot(euler_angle)
|
312 |
+
rot_inv = rot.permute(0, 2, 1)
|
313 |
+
trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))
|
314 |
+
|
315 |
+
pose = torch.eye(4, dtype=torch.float32)
|
316 |
+
save_ids = ['train', 'val']
|
317 |
+
train_val_ids = [train_ids, val_ids]
|
318 |
+
mean_z = -float(torch.mean(trans[:, 2]).item())
|
319 |
+
|
320 |
+
for split in range(2):
|
321 |
+
transform_dict = dict()
|
322 |
+
transform_dict['focal_len'] = float(focal_len[0])
|
323 |
+
transform_dict['cx'] = float(w/2.0)
|
324 |
+
transform_dict['cy'] = float(h/2.0)
|
325 |
+
transform_dict['frames'] = []
|
326 |
+
ids = train_val_ids[split]
|
327 |
+
save_id = save_ids[split]
|
328 |
+
|
329 |
+
for i in ids:
|
330 |
+
i = i.item()
|
331 |
+
frame_dict = dict()
|
332 |
+
frame_dict['img_id'] = i
|
333 |
+
frame_dict['aud_id'] = i
|
334 |
+
|
335 |
+
pose[:3, :3] = rot_inv[i]
|
336 |
+
pose[:3, 3] = trans_inv[i, :, 0]
|
337 |
+
|
338 |
+
frame_dict['transform_matrix'] = pose.numpy().tolist()
|
339 |
+
|
340 |
+
transform_dict['frames'].append(frame_dict)
|
341 |
+
|
342 |
+
with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
|
343 |
+
json.dump(transform_dict, fp, indent=2, separators=(',', ': '))
|
344 |
+
|
345 |
+
print(f'[INFO] ===== finished saving transforms =====')
|
346 |
+
|
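save_transforms stores the inverted head pose (`rot_inv`, `trans_inv`), i.e. the camera-to-world form that transforms.json-style NeRF loaders expect. A quick sanity check on euler2rot, assuming it is in scope as defined above: zero angles must give the identity rotation.

```python
import torch
# with theta = phi = psi = 0, each axis rotation is the identity, so their product is too
R = euler2rot(torch.zeros(1, 3))
assert torch.allclose(R, torch.eye(3).unsqueeze(0))
```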
347 |
+
|
348 |
+
if __name__ == '__main__':
|
349 |
+
parser = argparse.ArgumentParser()
|
350 |
+
parser.add_argument('path', type=str, help="path to video file")
|
351 |
+
parser.add_argument('--task', type=int, default=-1, help="-1 means all")
|
352 |
+
parser.add_argument('--asr', type=str, default='deepspeech', help="wav2vec or deepspeech")
|
353 |
+
|
354 |
+
opt = parser.parse_args()
|
355 |
+
|
356 |
+
base_dir = os.path.dirname(opt.path)
|
357 |
+
|
358 |
+
wav_path = os.path.join(base_dir, 'aud.wav')
|
359 |
+
ori_imgs_dir = os.path.join(base_dir, 'ori_imgs')
|
360 |
+
parsing_dir = os.path.join(base_dir, 'parsing')
|
361 |
+
gt_imgs_dir = os.path.join(base_dir, 'gt_imgs')
|
362 |
+
torso_imgs_dir = os.path.join(base_dir, 'torso_imgs')
|
363 |
+
|
364 |
+
os.makedirs(ori_imgs_dir, exist_ok=True)
|
365 |
+
os.makedirs(parsing_dir, exist_ok=True)
|
366 |
+
os.makedirs(gt_imgs_dir, exist_ok=True)
|
367 |
+
os.makedirs(torso_imgs_dir, exist_ok=True)
|
368 |
+
|
369 |
+
|
370 |
+
# extract audio
|
371 |
+
if opt.task == -1 or opt.task == 1:
|
372 |
+
extract_audio(opt.path, wav_path)
|
373 |
+
|
374 |
+
# extract audio features
|
375 |
+
if opt.task == -1 or opt.task == 2:
|
376 |
+
extract_audio_features(wav_path, mode=opt.asr)
|
377 |
+
|
378 |
+
# extract images
|
379 |
+
if opt.task == -1 or opt.task == 3:
|
380 |
+
extract_images(opt.path, ori_imgs_dir)
|
381 |
+
|
382 |
+
# face parsing
|
383 |
+
if opt.task == -1 or opt.task == 4:
|
384 |
+
extract_semantics(ori_imgs_dir, parsing_dir)
|
385 |
+
|
386 |
+
# extract bg
|
387 |
+
if opt.task == -1 or opt.task == 5:
|
388 |
+
extract_background(base_dir, ori_imgs_dir)
|
389 |
+
|
390 |
+
# extract torso images and gt_images
|
391 |
+
if opt.task == -1 or opt.task == 6:
|
392 |
+
extract_torso_and_gt(base_dir, ori_imgs_dir)
|
393 |
+
|
394 |
+
# extract face landmarks
|
395 |
+
if opt.task == -1 or opt.task == 7:
|
396 |
+
extract_landmarks(ori_imgs_dir)
|
397 |
+
|
398 |
+
# face tracking
|
399 |
+
if opt.task == -1 or opt.task == 8:
|
400 |
+
face_tracking(ori_imgs_dir)
|
401 |
+
|
402 |
+
# save transforms.json
|
403 |
+
if opt.task == -1 or opt.task == 9:
|
404 |
+
save_transforms(base_dir, ori_imgs_dir)
|
405 |
+
|
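Taken together, tasks 1-9 form the full preprocessing pipeline, and the default `--task -1` runs them in order. A hypothetical end-to-end invocation (the dataset path is a placeholder, not from this commit):

```
python data_utils/process.py data/<id>/<id>.mp4 --asr deepspeech
```

Since `base_dir` is derived from the video's directory, the video must already sit inside its own dataset folder; all outputs (`aud.wav`, `ori_imgs/`, `parsing/`, `gt_imgs/`, `torso_imgs/`, `bc.jpg`, `track_params.pt`, `transforms_*.json`) are written next to it.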
data_utils/wav2mel.py
ADDED
@@ -0,0 +1,167 @@
1 |
+
import librosa
|
2 |
+
import librosa.filters
|
3 |
+
import numpy as np
|
4 |
+
from scipy import signal
|
5 |
+
from wav2mel_hparams import hparams as hp
|
6 |
+
from librosa.core.audio import resample
|
7 |
+
import soundfile as sf
|
8 |
+
|
9 |
+
def load_wav(path, sr):
|
10 |
+
return librosa.core.load(path, sr=sr)
|
11 |
+
|
12 |
+
def preemphasis(wav, k, preemphasize=True):
|
13 |
+
if preemphasize:
|
14 |
+
return signal.lfilter([1, -k], [1], wav)
|
15 |
+
return wav
|
16 |
+
|
17 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
18 |
+
if inv_preemphasize:
|
19 |
+
return signal.lfilter([1], [1, -k], wav)
|
20 |
+
return wav
|
21 |
+
|
22 |
+
def get_hop_size():
|
23 |
+
hop_size = hp.hop_size
|
24 |
+
if hop_size is None:
|
25 |
+
assert hp.frame_shift_ms is not None
|
26 |
+
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
27 |
+
return hop_size
|
28 |
+
|
29 |
+
def linearspectrogram(wav):
|
30 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
31 |
+
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
32 |
+
|
33 |
+
if hp.signal_normalization:
|
34 |
+
return _normalize(S)
|
35 |
+
return S
|
36 |
+
|
37 |
+
def melspectrogram(wav):
|
38 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
39 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
40 |
+
|
41 |
+
if hp.signal_normalization:
|
42 |
+
return _normalize(S)
|
43 |
+
return S
|
44 |
+
|
45 |
+
def _stft(y):
|
46 |
+
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
47 |
+
|
48 |
+
##########################################################
|
49 |
+
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
50 |
+
def num_frames(length, fsize, fshift):
|
51 |
+
"""Compute number of time frames of spectrogram
|
52 |
+
"""
|
53 |
+
pad = (fsize - fshift)
|
54 |
+
if length % fshift == 0:
|
55 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
56 |
+
else:
|
57 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
58 |
+
return M
|
59 |
+
|
60 |
+
|
61 |
+
def pad_lr(x, fsize, fshift):
|
62 |
+
"""Compute left and right padding
|
63 |
+
"""
|
64 |
+
M = num_frames(len(x), fsize, fshift)
|
65 |
+
pad = (fsize - fshift)
|
66 |
+
T = len(x) + 2 * pad
|
67 |
+
r = (M - 1) * fshift + fsize - T
|
68 |
+
return pad, pad + r
|
69 |
+
##########################################################
|
70 |
+
#Librosa correct padding
|
71 |
+
def librosa_pad_lr(x, fsize, fshift):
|
72 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
73 |
+
|
74 |
+
# Conversions
|
75 |
+
_mel_basis = None
|
76 |
+
|
77 |
+
def _linear_to_mel(spectogram):
|
78 |
+
global _mel_basis
|
79 |
+
if _mel_basis is None:
|
80 |
+
_mel_basis = _build_mel_basis()
|
81 |
+
return np.dot(_mel_basis, spectogram)
|
82 |
+
|
83 |
+
def _build_mel_basis():
|
84 |
+
assert hp.fmax <= hp.sample_rate // 2
|
85 |
+
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
|
86 |
+
fmin=hp.fmin, fmax=hp.fmax)
|
87 |
+
|
88 |
+
def _amp_to_db(x):
|
89 |
+
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
90 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
91 |
+
|
92 |
+
def _db_to_amp(x):
|
93 |
+
return np.power(10.0, (x) * 0.05)
|
94 |
+
|
95 |
+
def _normalize(S):
|
96 |
+
if hp.allow_clipping_in_normalization:
|
97 |
+
if hp.symmetric_mels:
|
98 |
+
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
99 |
+
-hp.max_abs_value, hp.max_abs_value)
|
100 |
+
else:
|
101 |
+
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
102 |
+
|
103 |
+
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
104 |
+
if hp.symmetric_mels:
|
105 |
+
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
106 |
+
else:
|
107 |
+
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
108 |
+
|
109 |
+
def _denormalize(D):
|
110 |
+
if hp.allow_clipping_in_normalization:
|
111 |
+
if hp.symmetric_mels:
|
112 |
+
return (((np.clip(D, -hp.max_abs_value,
|
113 |
+
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
114 |
+
+ hp.min_level_db)
|
115 |
+
else:
|
116 |
+
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
117 |
+
|
118 |
+
if hp.symmetric_mels:
|
119 |
+
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
120 |
+
else:
|
121 |
+
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
def wav2mel(wav, sr):
|
126 |
+
wav16k = resample(wav, orig_sr=sr, target_sr=16000)
|
127 |
+
# print('wav16k', wav16k.shape, wav16k.dtype)
|
128 |
+
mel = melspectrogram(wav16k)
|
129 |
+
# print('mel', mel.shape, mel.dtype)
|
130 |
+
if np.isnan(mel.reshape(-1)).sum() > 0:
|
131 |
+
raise ValueError(
|
132 |
+
'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
|
133 |
+
# mel.dtype = np.float32
|
134 |
+
mel_chunks = []
|
135 |
+
mel_idx_multiplier = 80. / 25
|
136 |
+
mel_step_size = 8
|
137 |
+
i = start_idx = 0
|
138 |
+
while start_idx < len(mel[0]):
|
139 |
+
start_idx = int(i * mel_idx_multiplier)
|
140 |
+
if start_idx + mel_step_size // 2 > len(mel[0]):
|
141 |
+
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
|
142 |
+
elif start_idx - mel_step_size // 2 < 0:
|
143 |
+
mel_chunks.append(mel[:, :mel_step_size])
|
144 |
+
else:
|
145 |
+
mel_chunks.append(mel[:, start_idx - mel_step_size // 2 : start_idx + mel_step_size // 2])
|
146 |
+
i += 1
|
147 |
+
return mel_chunks
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
import argparse
|
153 |
+
|
154 |
+
parser = argparse.ArgumentParser()
|
155 |
+
parser.add_argument('--wav', type=str, default='')
|
156 |
+
parser.add_argument('--save_feats', action='store_true')
|
157 |
+
|
158 |
+
opt = parser.parse_args()
|
159 |
+
|
160 |
+
wav, sr = librosa.core.load(opt.wav)
|
161 |
+
mel_chunks = np.array(wav2mel(wav.T, sr))
|
162 |
+
print(mel_chunks.shape, mel_chunks.transpose(0,2,1).shape)
|
163 |
+
|
164 |
+
if opt.save_feats:
|
165 |
+
save_path = opt.wav.replace('.wav', '_mel.npy')
|
166 |
+
np.save(save_path, mel_chunks.transpose(0,2,1))
|
167 |
+
print(f"[INFO] saved logits to {save_path}")
|
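The constants in `wav2mel` above follow from the hparams below: with `hop_size=200` at 16 kHz the mel runs at 80 frames per second, so at 25 video fps each video frame advances `80 / 25 = 3.2` mel frames (`mel_idx_multiplier`), and every chunk is an 8-frame (~100 ms) window centered on that position, clamped at both ends of the spectrogram. A tiny sketch of the mapping (values assumed from wav2mel_hparams):

```python
mel_fps = 16000 / 200                      # 80 mel frames per second
video_fps = 25
mel_idx_multiplier = mel_fps / video_fps   # 3.2 mel frames per video frame
print(int(10 * mel_idx_multiplier))        # video frame 10 centers on mel frame 32
```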
data_utils/wav2mel_hparams.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
class HParams:
|
2 |
+
def __init__(self, **kwargs):
|
3 |
+
self.data = {}
|
4 |
+
|
5 |
+
for key, value in kwargs.items():
|
6 |
+
self.data[key] = value
|
7 |
+
|
8 |
+
def __getattr__(self, key):
|
9 |
+
if key not in self.data:
|
10 |
+
raise AttributeError("'HParams' object has no attribute %s" % key)
|
11 |
+
return self.data[key]
|
12 |
+
|
13 |
+
def set_hparam(self, key, value):
|
14 |
+
self.data[key] = value
|
15 |
+
|
16 |
+
# Default hyperparameters
|
17 |
+
hparams = HParams(
|
18 |
+
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
|
19 |
+
# network
|
20 |
+
rescale=True, # Whether to rescale audio prior to preprocessing
|
21 |
+
rescaling_max=0.9, # Rescaling value
|
22 |
+
|
23 |
+
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
|
24 |
+
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
|
25 |
+
# Does not work if n_fft is not a multiple of hop_size!!
|
26 |
+
use_lws=False,
|
27 |
+
|
28 |
+
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
|
29 |
+
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
|
30 |
+
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
|
31 |
+
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
|
32 |
+
|
33 |
+
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
|
34 |
+
|
35 |
+
# Mel and Linear spectrograms normalization/scaling and clipping
|
36 |
+
signal_normalization=True,
|
37 |
+
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
|
38 |
+
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
|
39 |
+
symmetric_mels=True,
|
40 |
+
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
|
41 |
+
# faster and cleaner convergence)
|
42 |
+
max_abs_value=4.,
|
43 |
+
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
|
44 |
+
# be too big to avoid gradient explosion,
|
45 |
+
# not too small for fast convergence)
|
46 |
+
# Contribution by @begeekmyfriend
|
47 |
+
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
|
48 |
+
# levels. Also allows for better G&L phase reconstruction)
|
49 |
+
preemphasize=True, # whether to apply filter
|
50 |
+
preemphasis=0.97, # filter coefficient.
|
51 |
+
|
52 |
+
# Limits
|
53 |
+
min_level_db=-100,
|
54 |
+
ref_level_db=20,
|
55 |
+
fmin=65,
|
56 |
+
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
|
57 |
+
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
58 |
+
fmax=6000, # To be increased/reduced depending on data.
|
59 |
+
|
60 |
+
###################### Our training parameters #################################
|
61 |
+
img_size=96,
|
62 |
+
fps=25,
|
63 |
+
|
64 |
+
batch_size=16,
|
65 |
+
initial_learning_rate=1e-4,
|
66 |
+
nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
|
67 |
+
num_workers=16,
|
68 |
+
checkpoint_interval=3000,
|
69 |
+
eval_interval=3000,
|
70 |
+
save_optimizer_state=True,
|
71 |
+
|
72 |
+
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
|
73 |
+
syncnet_batch_size=64,
|
74 |
+
syncnet_lr=1e-4,
|
75 |
+
syncnet_eval_interval=10000,
|
76 |
+
syncnet_checkpoint_interval=10000,
|
77 |
+
|
78 |
+
disc_wt=0.07,
|
79 |
+
disc_initial_learning_rate=1e-4,
|
80 |
+
)
|
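HParams is just an attribute-style wrapper over a plain dict, so values read like attributes and can be overridden at runtime. A minimal usage sketch:

```python
from wav2mel_hparams import hparams

print(hparams.sample_rate)        # 16000
hparams.set_hparam('fmax', 7600)  # runtime override (7600 is an arbitrary example value)
print(hparams.fmax)
```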
data_utils/wav2vec.py
ADDED
@@ -0,0 +1,420 @@
1 |
+
import time
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers import AutoModelForCTC, AutoProcessor
|
6 |
+
|
7 |
+
import pyaudio
|
8 |
+
import soundfile as sf
|
9 |
+
import resampy
|
10 |
+
|
11 |
+
from queue import Queue
|
12 |
+
from threading import Thread, Event
|
13 |
+
|
14 |
+
|
15 |
+
def _read_frame(stream, exit_event, queue, chunk):
|
16 |
+
|
17 |
+
while True:
|
18 |
+
if exit_event.is_set():
|
19 |
+
print(f'[INFO] read frame thread ends')
|
20 |
+
break
|
21 |
+
frame = stream.read(chunk, exception_on_overflow=False)
|
22 |
+
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
|
23 |
+
queue.put(frame)
|
24 |
+
|
25 |
+
def _play_frame(stream, exit_event, queue, chunk):
|
26 |
+
|
27 |
+
while True:
|
28 |
+
if exit_event.is_set():
|
29 |
+
print(f'[INFO] play frame thread ends')
|
30 |
+
break
|
31 |
+
frame = queue.get()
|
32 |
+
frame = (frame * 32767).astype(np.int16).tobytes()
|
33 |
+
stream.write(frame, chunk)
|
34 |
+
|
35 |
+
class ASR:
|
36 |
+
def __init__(self, opt):
|
37 |
+
|
38 |
+
self.opt = opt
|
39 |
+
|
40 |
+
self.play = opt.asr_play
|
41 |
+
|
42 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
43 |
+
self.fps = opt.fps # 20 ms per frame
|
44 |
+
self.sample_rate = 16000
|
45 |
+
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
46 |
+
self.mode = 'live' if opt.asr_wav == '' else 'file'
|
47 |
+
|
48 |
+
if 'esperanto' in self.opt.asr_model:
|
49 |
+
self.audio_dim = 44
|
50 |
+
elif 'deepspeech' in self.opt.asr_model:
|
51 |
+
self.audio_dim = 29
|
52 |
+
else:
|
53 |
+
self.audio_dim = 32
|
54 |
+
|
55 |
+
# prepare context cache
|
56 |
+
# each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
|
57 |
+
self.context_size = opt.m
|
58 |
+
self.stride_left_size = opt.l
|
59 |
+
self.stride_right_size = opt.r
|
60 |
+
self.text = '[START]\n'
|
61 |
+
self.terminated = False
|
62 |
+
self.frames = []
|
63 |
+
|
64 |
+
# pad left frames
|
65 |
+
if self.stride_left_size > 0:
|
66 |
+
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
|
67 |
+
|
68 |
+
|
69 |
+
self.exit_event = Event()
|
70 |
+
self.audio_instance = pyaudio.PyAudio()
|
71 |
+
|
72 |
+
# create input stream
|
73 |
+
if self.mode == 'file':
|
74 |
+
self.file_stream = self.create_file_stream()
|
75 |
+
else:
|
76 |
+
# start a background process to read frames
|
77 |
+
self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
|
78 |
+
self.queue = Queue()
|
79 |
+
self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
|
80 |
+
|
81 |
+
# play out the audio too...?
|
82 |
+
if self.play:
|
83 |
+
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
|
84 |
+
self.output_queue = Queue()
|
85 |
+
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
|
86 |
+
|
87 |
+
# current location of audio
|
88 |
+
self.idx = 0
|
89 |
+
|
90 |
+
# create wav2vec model
|
91 |
+
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
|
92 |
+
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
|
93 |
+
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
|
94 |
+
|
95 |
+
# prepare to save logits
|
96 |
+
if self.opt.asr_save_feats:
|
97 |
+
self.all_feats = []
|
98 |
+
|
99 |
+
# the extracted features
|
100 |
+
# use a ring buffer (loop queue) to efficiently record endless features: [f--t---][-------][-------]
|
101 |
+
self.feat_buffer_size = 4
|
102 |
+
self.feat_buffer_idx = 0
|
103 |
+
self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
|
104 |
+
|
105 |
+
# TODO: hard coded 16 and 8 window size...
|
106 |
+
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
|
107 |
+
self.tail = 8
|
108 |
+
# attention window...
|
109 |
+
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
|
110 |
+
|
111 |
+
# warm up steps needed: mid + right + window_size + attention_size
|
112 |
+
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
|
113 |
+
|
114 |
+
self.listening = False
|
115 |
+
self.playing = False
|
116 |
+
|
117 |
+
def listen(self):
|
118 |
+
# start
|
119 |
+
if self.mode == 'live' and not self.listening:
|
120 |
+
print(f'[INFO] starting read frame thread...')
|
121 |
+
self.process_read_frame.start()
|
122 |
+
self.listening = True
|
123 |
+
|
124 |
+
if self.play and not self.playing:
|
125 |
+
print(f'[INFO] starting play frame thread...')
|
126 |
+
self.process_play_frame.start()
|
127 |
+
self.playing = True
|
128 |
+
|
129 |
+
def stop(self):
|
130 |
+
|
131 |
+
self.exit_event.set()
|
132 |
+
|
133 |
+
if self.play:
|
134 |
+
self.output_stream.stop_stream()
|
135 |
+
self.output_stream.close()
|
136 |
+
if self.playing:
|
137 |
+
self.process_play_frame.join()
|
138 |
+
self.playing = False
|
139 |
+
|
140 |
+
if self.mode == 'live':
|
141 |
+
self.input_stream.stop_stream()
|
142 |
+
self.input_stream.close()
|
143 |
+
if self.listening:
|
144 |
+
self.process_read_frame.join()
|
145 |
+
self.listening = False
|
146 |
+
|
147 |
+
|
148 |
+
def __enter__(self):
|
149 |
+
return self
|
150 |
+
|
151 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
152 |
+
|
153 |
+
self.stop()
|
154 |
+
|
155 |
+
if self.mode == 'live':
|
156 |
+
# live mode: also print the result text.
|
157 |
+
self.text += '\n[END]'
|
158 |
+
print(self.text)
|
159 |
+
|
160 |
+
def get_next_feat(self):
|
161 |
+
# return a [1/8, 16] window, for the next input to nerf side.
|
162 |
+
|
163 |
+
while len(self.att_feats) < 8:
|
164 |
+
# [------f+++t-----]
|
165 |
+
if self.front < self.tail:
|
166 |
+
feat = self.feat_queue[self.front:self.tail]
|
167 |
+
# [++t-----------f+]
|
168 |
+
else:
|
169 |
+
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
|
170 |
+
|
171 |
+
self.front = (self.front + 2) % self.feat_queue.shape[0]
|
172 |
+
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
|
173 |
+
|
174 |
+
# print(self.front, self.tail, feat.shape)
|
175 |
+
|
176 |
+
self.att_feats.append(feat.permute(1, 0))
|
177 |
+
|
178 |
+
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
|
179 |
+
|
180 |
+
# discard old
|
181 |
+
self.att_feats = self.att_feats[1:]
|
182 |
+
|
183 |
+
return att_feat
|
184 |
+
|
185 |
+
def run_step(self):
|
186 |
+
|
187 |
+
if self.terminated:
|
188 |
+
return
|
189 |
+
|
190 |
+
# get a frame of audio
|
191 |
+
frame = self.get_audio_frame()
|
192 |
+
|
193 |
+
# the last frame
|
194 |
+
if frame is None:
|
195 |
+
# terminate, but still run the network on the leftover frames
|
196 |
+
self.terminated = True
|
197 |
+
else:
|
198 |
+
self.frames.append(frame)
|
199 |
+
# put to output
|
200 |
+
if self.play:
|
201 |
+
self.output_queue.put(frame)
|
202 |
+
# context not enough, do not run network.
|
203 |
+
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
|
204 |
+
return
|
205 |
+
|
206 |
+
inputs = np.concatenate(self.frames) # [N * chunk]
|
207 |
+
|
208 |
+
# discard the old part to save memory
|
209 |
+
if not self.terminated:
|
210 |
+
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
211 |
+
|
212 |
+
logits, labels, text = self.frame_to_text(inputs)
|
213 |
+
feats = logits # better lip-sync than labels
|
214 |
+
|
215 |
+
# save feats
|
216 |
+
if self.opt.asr_save_feats:
|
217 |
+
self.all_feats.append(feats)
|
218 |
+
|
219 |
+
# record the feats efficiently.. (no concat, constant memory)
|
220 |
+
if not self.terminated:
|
221 |
+
start = self.feat_buffer_idx * self.context_size
|
222 |
+
end = start + feats.shape[0]
|
223 |
+
self.feat_queue[start:end] = feats
|
224 |
+
self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
|
225 |
+
|
226 |
+
# very naive, just concat the text output.
|
227 |
+
if text != '':
|
228 |
+
self.text = self.text + ' ' + text
|
229 |
+
|
230 |
+
# will only run once, at termination
|
231 |
+
if self.terminated:
|
232 |
+
self.text += '\n[END]'
|
233 |
+
print(self.text)
|
234 |
+
if self.opt.asr_save_feats:
|
235 |
+
print(f'[INFO] save all feats for training purpose... ')
|
236 |
+
feats = torch.cat(self.all_feats, dim=0) # [N, C]
|
237 |
+
# print('[INFO] before unfold', feats.shape)
|
238 |
+
window_size = 16
|
239 |
+
padding = window_size // 2
|
240 |
+
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
|
241 |
+
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
|
242 |
+
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
|
243 |
+
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
|
244 |
+
# print('[INFO] after unfold', unfold_feats.shape)
|
245 |
+
# save to a npy file
|
246 |
+
if 'esperanto' in self.opt.asr_model:
|
247 |
+
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
|
248 |
+
else:
|
249 |
+
output_path = self.opt.asr_wav.replace('.wav', '.npy')
|
250 |
+
np.save(output_path, unfold_feats.cpu().numpy())
|
251 |
+
print(f"[INFO] saved logits to {output_path}")
|
252 |
+
|
253 |
+
def create_file_stream(self):
|
254 |
+
|
255 |
+
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
|
256 |
+
stream = stream.astype(np.float32)
|
257 |
+
|
258 |
+
if stream.ndim > 1:
|
259 |
+
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
260 |
+
stream = stream[:, 0]
|
261 |
+
|
262 |
+
if sample_rate != self.sample_rate:
|
263 |
+
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
264 |
+
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
265 |
+
|
266 |
+
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
|
267 |
+
|
268 |
+
return stream
|
269 |
+
|
270 |
+
|
271 |
+
def create_pyaudio_stream(self):
|
272 |
+
|
273 |
+
import pyaudio
|
274 |
+
|
275 |
+
print(f'[INFO] creating live audio stream ...')
|
276 |
+
|
277 |
+
audio = pyaudio.PyAudio()
|
278 |
+
|
279 |
+
# get devices
|
280 |
+
info = audio.get_host_api_info_by_index(0)
|
281 |
+
n_devices = info.get('deviceCount')
|
282 |
+
|
283 |
+
for i in range(0, n_devices):
|
284 |
+
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
|
285 |
+
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
|
286 |
+
print(f'[INFO] choose audio device {name}, id {i}')
|
287 |
+
break
|
288 |
+
|
289 |
+
# get stream
|
290 |
+
stream = audio.open(input_device_index=i,
|
291 |
+
format=pyaudio.paInt16,
|
292 |
+
channels=1,
|
293 |
+
rate=self.sample_rate,
|
294 |
+
input=True,
|
295 |
+
frames_per_buffer=self.chunk)
|
296 |
+
|
297 |
+
return audio, stream
|
298 |
+
|
299 |
+
|
300 |
+
def get_audio_frame(self):
|
301 |
+
|
302 |
+
if self.mode == 'file':
|
303 |
+
|
304 |
+
if self.idx < self.file_stream.shape[0]:
|
305 |
+
frame = self.file_stream[self.idx: self.idx + self.chunk]
|
306 |
+
self.idx = self.idx + self.chunk
|
307 |
+
return frame
|
308 |
+
else:
|
309 |
+
return None
|
310 |
+
|
311 |
+
else:
|
312 |
+
|
313 |
+
frame = self.queue.get()
|
314 |
+
# print(f'[INFO] get frame {frame.shape}')
|
315 |
+
|
316 |
+
self.idx = self.idx + self.chunk
|
317 |
+
|
318 |
+
return frame
|
319 |
+
|
320 |
+
|
321 |
+
def frame_to_text(self, frame):
|
322 |
+
# frame: [N * 320], N = (context_size + 2 * stride_size)
|
323 |
+
|
324 |
+
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
|
325 |
+
|
326 |
+
with torch.no_grad():
|
327 |
+
result = self.model(inputs.input_values.to(self.device))
|
328 |
+
logits = result.logits # [1, N - 1, 32]
|
329 |
+
|
330 |
+
# cut off stride
|
331 |
+
left = max(0, self.stride_left_size)
|
332 |
+
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
333 |
+
|
334 |
+
# do not cut right if terminated.
|
335 |
+
if self.terminated:
|
336 |
+
right = logits.shape[1]
|
337 |
+
|
338 |
+
logits = logits[:, left:right]
|
339 |
+
|
340 |
+
# print(frame.shape, inputs.input_values.shape, logits.shape)
|
341 |
+
|
342 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
343 |
+
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
344 |
+
|
345 |
+
|
346 |
+
# for esperanto
|
347 |
+
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
|
348 |
+
|
349 |
+
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
|
350 |
+
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
|
351 |
+
# print(predicted_ids[0])
|
352 |
+
# print(transcription)
|
353 |
+
|
354 |
+
return logits[0], predicted_ids[0], transcription # [N,]
|
355 |
+
|
356 |
+
|
357 |
+
def run(self):
|
358 |
+
|
359 |
+
self.listen()
|
360 |
+
|
361 |
+
while not self.terminated:
|
362 |
+
self.run_step()
|
363 |
+
|
364 |
+
def clear_queue(self):
|
365 |
+
# clear the queue, to reduce potential latency...
|
366 |
+
print(f'[INFO] clear queue')
|
367 |
+
if self.mode == 'live':
|
368 |
+
self.queue.queue.clear()
|
369 |
+
if self.play:
|
370 |
+
self.output_queue.queue.clear()
|
371 |
+
|
372 |
+
def warm_up(self):
|
373 |
+
|
374 |
+
self.listen()
|
375 |
+
|
376 |
+
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
377 |
+
t = time.time()
|
378 |
+
for _ in range(self.warm_up_steps):
|
379 |
+
self.run_step()
|
380 |
+
if torch.cuda.is_available():
|
381 |
+
torch.cuda.synchronize()
|
382 |
+
t = time.time() - t
|
383 |
+
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
384 |
+
|
385 |
+
self.clear_queue()
|
386 |
+
|
387 |
+
|
388 |
+
|
389 |
+
|
390 |
+
if __name__ == '__main__':
|
391 |
+
import argparse
|
392 |
+
|
393 |
+
parser = argparse.ArgumentParser()
|
394 |
+
parser.add_argument('--wav', type=str, default='')
|
395 |
+
parser.add_argument('--play', action='store_true', help="play out the audio")
|
396 |
+
|
397 |
+
parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
|
398 |
+
# parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
|
399 |
+
|
400 |
+
parser.add_argument('--save_feats', action='store_true')
|
401 |
+
# audio FPS
|
402 |
+
parser.add_argument('--fps', type=int, default=50)
|
403 |
+
# sliding window left-middle-right length.
|
404 |
+
parser.add_argument('-l', type=int, default=10)
|
405 |
+
parser.add_argument('-m', type=int, default=50)
|
406 |
+
parser.add_argument('-r', type=int, default=10)
|
407 |
+
|
408 |
+
opt = parser.parse_args()
|
409 |
+
|
410 |
+
# fix
|
411 |
+
opt.asr_wav = opt.wav
|
412 |
+
opt.asr_play = opt.play
|
413 |
+
opt.asr_model = opt.model
|
414 |
+
opt.asr_save_feats = opt.save_feats
|
415 |
+
|
416 |
+
if 'deepspeech' in opt.asr_model:
|
417 |
+
raise ValueError("DeepSpeech features should not use this code to extract...")
|
418 |
+
|
419 |
+
with ASR(opt) as asr:
|
420 |
+
asr.run()
|
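The script is meant for offline feature extraction with the default Esperanto wav2vec2 checkpoint: each 20 ms chunk at `--fps 50` yields roughly one logit frame, and the `-l`/`-m`/`-r` sliding window trades latency for context. A hypothetical invocation (the wav path is a placeholder):

```
python data_utils/wav2vec.py --wav data/<id>/aud.wav --save_feats
```

With the default model this writes the unfolded `[M, 16, 44]` feature windows to `data/<id>/aud_eo.npy`.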
tensorflow-models/deepspeech-0_1_0-b90017e8.pb.zip
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8866cb5098e4950e72912e49aa2ac192f279ecdd60f7721aa570b7a543451085
|
3 |
+
size 455423408
|
torch-hub/2DFAN4-cd938726ad.zip
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cd938726adb1f15f361263cce2db9cb820c42585fa8796ec72ce19107f369a46
|
3 |
+
size 96316515
|
torch-hub/s3fd-619a316812.pth
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543
|
3 |
+
size 89843225
|