DmitrMakeev
commited on
Commit
·
f03fe1a
1
Parent(s):
0b2b527
Upload interface.py
Browse files- tools/interface.py +190 -0
tools/interface.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from skimage import io,img_as_float32
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import subprocess
|
8 |
+
import pandas
|
9 |
+
from models.audio2pose import audio2poseLSTM
|
10 |
+
from scipy.io import wavfile
|
11 |
+
import python_speech_features
|
12 |
+
import pyworld
|
13 |
+
import config
|
14 |
+
import json
|
15 |
+
from scipy.interpolate import interp1d
|
16 |
+
|
17 |
+
def inter_pitch(y,y_flag):
|
18 |
+
frame_num = y.shape[0]
|
19 |
+
i = 0
|
20 |
+
last = -1
|
21 |
+
while(i<frame_num):
|
22 |
+
if y_flag[i] == 0:
|
23 |
+
while True:
|
24 |
+
if y_flag[i]==0:
|
25 |
+
if i == frame_num-1:
|
26 |
+
if last !=-1:
|
27 |
+
y[last+1:] = y[last]
|
28 |
+
i+=1
|
29 |
+
break
|
30 |
+
i+=1
|
31 |
+
else:
|
32 |
+
break
|
33 |
+
if i >= frame_num:
|
34 |
+
break
|
35 |
+
elif last == -1:
|
36 |
+
y[:i] = y[i]
|
37 |
+
else:
|
38 |
+
inter_num = i-last+1
|
39 |
+
fy = np.array([y[last],y[i]])
|
40 |
+
fx = np.linspace(0, 1, num=2)
|
41 |
+
f = interp1d(fx,fy)
|
42 |
+
fx_new = np.linspace(0,1,inter_num)
|
43 |
+
fy_new = f(fx_new)
|
44 |
+
y[last+1:i] = fy_new[1:-1]
|
45 |
+
last = i
|
46 |
+
i+=1
|
47 |
+
|
48 |
+
else:
|
49 |
+
last = i
|
50 |
+
i+=1
|
51 |
+
return y
|
52 |
+
|
53 |
+
|
54 |
+
def load_ckpt(checkpoint_path, generator = None, kp_detector = None, ph2kp = None):
|
55 |
+
checkpoint = torch.load(checkpoint_path)
|
56 |
+
if ph2kp is not None:
|
57 |
+
ph2kp.load_state_dict(checkpoint['ph2kp'])
|
58 |
+
if generator is not None:
|
59 |
+
generator.load_state_dict(checkpoint['generator'])
|
60 |
+
if kp_detector is not None:
|
61 |
+
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
62 |
+
|
63 |
+
def get_img_pose(img_path):
|
64 |
+
processor = config.OPENFACE_POSE_EXTRACTOR_PATH
|
65 |
+
|
66 |
+
tmp_dir = "samples/tmp_dir"
|
67 |
+
os.makedirs((tmp_dir),exist_ok=True)
|
68 |
+
subprocess.call([processor, "-f", img_path, "-out_dir", tmp_dir, "-pose"])
|
69 |
+
|
70 |
+
img_file = os.path.basename(img_path)[:-4]+".csv"
|
71 |
+
csv_file = os.path.join(tmp_dir,img_file)
|
72 |
+
pos_data = pandas.read_csv(csv_file)
|
73 |
+
i = 0
|
74 |
+
pose = [pos_data["pose_Rx"][i], pos_data["pose_Ry"][i], pos_data["pose_Rz"][i],pos_data["pose_Tx"][i], pos_data["pose_Ty"][i], pos_data["pose_Tz"][i]]
|
75 |
+
# pose = [pose]
|
76 |
+
pose = np.array(pose,dtype=np.float32)
|
77 |
+
return pose
|
78 |
+
|
79 |
+
def read_img(path):
|
80 |
+
img = io.imread(path)[:,:,:3]
|
81 |
+
img = cv2.resize(img, (256, 256))
|
82 |
+
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
83 |
+
img = np.array(img_as_float32(img))
|
84 |
+
img = img.transpose((2, 0, 1))
|
85 |
+
img = torch.from_numpy(img).unsqueeze(0)
|
86 |
+
return img
|
87 |
+
|
88 |
+
|
89 |
+
def parse_phoneme_file(phoneme_path,use_index = True):
|
90 |
+
with open(phoneme_path,'r') as f:
|
91 |
+
result_text = json.load(f)
|
92 |
+
frame_num = int(result_text[-1]['phones'][-1]['ed']/100*25)
|
93 |
+
phoneset_list = []
|
94 |
+
index = 0
|
95 |
+
|
96 |
+
word_len = len(result_text)
|
97 |
+
word_index = 0
|
98 |
+
phone_index = 0
|
99 |
+
cur_phone_list = result_text[0]["phones"]
|
100 |
+
phone_len = len(cur_phone_list)
|
101 |
+
cur_end = cur_phone_list[0]["ed"]
|
102 |
+
|
103 |
+
phone_list = []
|
104 |
+
|
105 |
+
phoneset_list.append(cur_phone_list[0]["ph"])
|
106 |
+
i = 0
|
107 |
+
while i < frame_num:
|
108 |
+
if i * 4 < cur_end:
|
109 |
+
phone_list.append(cur_phone_list[phone_index]["ph"])
|
110 |
+
i += 1
|
111 |
+
else:
|
112 |
+
phone_index += 1
|
113 |
+
if phone_index >= phone_len:
|
114 |
+
word_index += 1
|
115 |
+
if word_index >= word_len:
|
116 |
+
phone_list.append(cur_phone_list[-1]["ph"])
|
117 |
+
i += 1
|
118 |
+
else:
|
119 |
+
phone_index = 0
|
120 |
+
cur_phone_list = result_text[word_index]["phones"]
|
121 |
+
phone_len = len(cur_phone_list)
|
122 |
+
cur_end = cur_phone_list[phone_index]["ed"]
|
123 |
+
phoneset_list.append(cur_phone_list[phone_index]["ph"])
|
124 |
+
index += 1
|
125 |
+
else:
|
126 |
+
# print(word_index,phone_index)
|
127 |
+
cur_end = cur_phone_list[phone_index]["ed"]
|
128 |
+
phoneset_list.append(cur_phone_list[phone_index]["ph"])
|
129 |
+
index += 1
|
130 |
+
|
131 |
+
with open("phindex.json") as f:
|
132 |
+
ph2index = json.load(f)
|
133 |
+
if use_index:
|
134 |
+
phone_list = [ph2index[p] for p in phone_list]
|
135 |
+
saves = {"phone_list": phone_list}
|
136 |
+
|
137 |
+
return saves
|
138 |
+
|
139 |
+
def get_audio_feature_from_audio(audio_path):
|
140 |
+
sample_rate, audio = wavfile.read(audio_path)
|
141 |
+
if len(audio.shape) == 2:
|
142 |
+
if np.min(audio[:, 0]) <= 0:
|
143 |
+
audio = audio[:, 1]
|
144 |
+
else:
|
145 |
+
audio = audio[:, 0]
|
146 |
+
|
147 |
+
audio = audio - np.mean(audio)
|
148 |
+
audio = audio / np.max(np.abs(audio))
|
149 |
+
a = python_speech_features.mfcc(audio, sample_rate)
|
150 |
+
b = python_speech_features.logfbank(audio, sample_rate)
|
151 |
+
c, _ = pyworld.harvest(audio, sample_rate, frame_period=10)
|
152 |
+
c_flag = (c == 0.0) ^ 1
|
153 |
+
c = inter_pitch(c, c_flag)
|
154 |
+
c = np.expand_dims(c, axis=1)
|
155 |
+
c_flag = np.expand_dims(c_flag, axis=1)
|
156 |
+
frame_num = np.min([a.shape[0], b.shape[0], c.shape[0]])
|
157 |
+
|
158 |
+
cat = np.concatenate([a[:frame_num], b[:frame_num], c[:frame_num], c_flag[:frame_num]], axis=1)
|
159 |
+
return cat
|
160 |
+
|
161 |
+
def get_pose_from_audio(img,audio,audio2pose):
|
162 |
+
|
163 |
+
num_frame = len(audio) // 4
|
164 |
+
|
165 |
+
minv = np.array([-0.6, -0.6, -0.6, -128.0, -128.0, 128.0], dtype=np.float32)
|
166 |
+
maxv = np.array([0.6, 0.6, 0.6, 128.0, 128.0, 384.0], dtype=np.float32)
|
167 |
+
generator = audio2poseLSTM().cuda().eval()
|
168 |
+
|
169 |
+
ckpt_para = torch.load(audio2pose)
|
170 |
+
|
171 |
+
generator.load_state_dict(ckpt_para["generator"])
|
172 |
+
generator.eval()
|
173 |
+
|
174 |
+
|
175 |
+
audio_seq = []
|
176 |
+
for i in range(num_frame):
|
177 |
+
audio_seq.append(audio[i*4:i*4+4])
|
178 |
+
|
179 |
+
audio = torch.from_numpy(np.array(audio_seq,dtype=np.float32)).unsqueeze(0).cuda()
|
180 |
+
|
181 |
+
x = {}
|
182 |
+
x ["img"] = img
|
183 |
+
x["audio"] = audio
|
184 |
+
poses = generator(x)
|
185 |
+
|
186 |
+
poses = poses.cpu().data.numpy()[0]
|
187 |
+
poses = (poses+1)/2*(maxv-minv)+minv
|
188 |
+
|
189 |
+
return poses
|
190 |
+
|