import os
from multiprocessing.dummy import Process, Manager, Queue  # dummy backs these "processes" with threads
import cv2
import time
from options.hifi_test_options import HifiTestOptions
from face_feature.hifi_image_api import HifiImage

# suppress onnxruntime warnings
import onnxruntime
onnxruntime.set_default_logger_severity(3)
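
# Usage sketch only: the actual CLI flag names live in options/hifi_test_options.py and may
# differ; the names below are assumptions mirroring the opt.* fields read in __main__.
#   python this_script.py --src_img_path face.jpg --video_path target.mp4 \
#       --output_dir results --model_name <model> --align_method 68 \
#       --use_gfpgan --sr_weight 1.0 --use_color_trans --color_trans_mode rct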


class GenInput(Process):
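    """Producer stage: extract the source face feature once, then read frames from the
    target video and push them onto frame_queue_in (a final None marks end of stream)."""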
    def __init__(self, feature_src_list_, frame_queue_in_, frame_queue_out_, video_cap, src_img_path):
        super().__init__()
        self.frame_queue_in = frame_queue_in_
        self.frame_queue_out = frame_queue_out_
        self.feature_src_list = feature_src_list_
        self.src_img_path = src_img_path
        self.video_cap = video_cap
        self.hi = HifiImage(crop_size=256)

    def run(self):
        src_latent, crop_face = self.hi.get_face_feature(self.src_img_path)
        human_feature = [src_latent, crop_face]
        self.feature_src_list.append([human_feature])

        while True:
            have_frame, frame = self.video_cap.read()
            if not have_frame:
                # tell the downstream stages that the video is exhausted
                self.frame_queue_in.put(None)
                print("no more frames")
                break
            self.frame_queue_in.put(frame)



def save_video_ffmpeg(video_path, swap_video_path, model_name=''):
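    """Extract the original video's audio track with ffmpeg (unless a .wav already exists
    next to it) and mux it into the swapped video; optionally suffix the output with model_name."""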
    # derive the audio path next to the original video: <video name>.wav
    audio_file_path = video_path.split('.')[-2] + '.wav'
    if not os.path.exists(audio_file_path):
        print('extract audio')
        os.system(
            'ffmpeg -y -hide_banner -loglevel error -i "'
            + str(video_path)
            + '" -f wav -vn "'
            + str(audio_file_path)
            + '"'
        )
    else:
        print('audio file exists')
    if os.path.exists(audio_file_path):
        no_audio_path = swap_video_path.replace('.mp4', '_no_audio.mp4')
        os.rename(swap_video_path, no_audio_path)
        print('add audio')
        os.system(
            'ffmpeg -y -hide_banner -loglevel error -i "'
            + str(no_audio_path)
            + '" -i "'
            + str(audio_file_path)
            + '" -c:v libx264 -c:a aac -b:v 40000k "'
            + str(swap_video_path)
            + '"'
        )
        os.remove(no_audio_path)
    if model_name != '':
        os.rename(swap_video_path, swap_video_path.replace('.mp4', '_%s.mp4' % model_name))
    if os.path.exists(audio_file_path):
        os.remove(audio_file_path)

def change_video_resolution(video_path, resize_video_path):
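    """Rescale the video to a width of 1080 px (ffmpeg scale=1080:-1), re-encoding with libx264/aac."""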
    print('change video resolution to 1080p')
    os.system(
        'ffmpeg -y -hide_banner -loglevel error -i "'
        + str(video_path)
        + '" -vf scale=1080:-1 -c:v libx264 -c:a aac -b:v 20000k "'
        + str(resize_video_path)
        + '"'
    )


class GetOutput(Process):
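    """Consumer stage: drain frame_queue_out, write the swapped frames to an .mp4 with
    cv2.VideoWriter, report throughput, then re-attach the original audio track."""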
    def __init__(self, frame_queue_out_, src_video_path, model_name, out_dir, video_fps, video_size, video_frame_count, image_name,
                 align_method, use_gfpgan, sr_weight, use_color_trans=False, color_trans_mode='rct'):
        super().__init__()
        self.frame_queue_out = frame_queue_out_
        self.src_video_path = src_video_path
        out_video_name = image_name + '_to_' + os.path.basename(src_video_path).split('.')[-2] + '_' + model_name + '_' + align_method + '.mp4'
        if use_gfpgan:
            out_video_name = out_video_name.replace('.mp4', '_sr_{}.mp4'.format(sr_weight))
        if use_color_trans:
            out_video_name = out_video_name.replace('.mp4', '_'+color_trans_mode+'.mp4')
        self.out_path = os.path.join(out_dir, out_video_name)
        print(self.out_path)
        self.videoWriter = cv2.VideoWriter(self.out_path, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, video_size)
        self.video_frame_count = video_frame_count

    def run(self):
        count = 0
        fps_count = 0
        start_time = time.time()
        while True:
            frame_out = self.frame_queue_out.get()
            if frame_out is None:
                # end-of-stream marker forwarded from upstream; stop early
                break
            fps_count += 1
            if fps_count % 100 == 0:
                end_time = time.time()
                print('fps: {}'.format(fps_count / (end_time - start_time)))
                start_time = time.time()
                fps_count = 0
            count += 1
            # write the frame before checking the stop condition so the last frame is not dropped
            self.videoWriter.write(frame_out)
            if count == self.video_frame_count:
                break
        self.videoWriter.release()
        start_time = time.time()
        save_video_ffmpeg(self.src_video_path, self.out_path)
        print("add audio cost:", time.time() - start_time)



class FaceSwap(Process):
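    """Worker stage: wraps HifiFaceRealTime, which consumes frames from frame_queue_in,
    swaps in the source face, and puts the results onto frame_queue_out."""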
    def __init__(self, feature_src_list_, frame_queue_in_, frame_queue_out_,
                 model_name='', align_method='68', use_gfpgan=True, sr_weight=1.0,
                 use_color_trans=False, color_trans_mode='rct'):
        super().__init__()
        from HifiFaceAPI_parallel_trt_roi_realtime_sr_api import HifiFaceRealTime
        self.hfrt = HifiFaceRealTime(feature_src_list_, frame_queue_in_, frame_queue_out_,
                                     model_name=model_name, align_method=align_method,
                                     use_gfpgan=use_gfpgan, sr_weight=sr_weight,
                                     use_color_trans=use_color_trans, color_trans_mode=color_trans_mode)

    def run(self):
        self.hfrt.forward()


if __name__ == '__main__':
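    # Three-stage pipeline: GenInput -> frame_queue_in -> FaceSwap -> frame_queue_out -> GetOutput.
    # multiprocessing.dummy backs these "processes" with threads, so the queues and the
    # managed list are simply in-process hand-off points between the stages.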
    frame_queue_in = Queue(2)
    frame_queue_out = Queue(2)
    manager = Manager()
    image_feature_src_list = manager.list()
    opt = HifiTestOptions().parse()

    model_name = opt.model_name
    align_method = opt.align_method
    use_gfpgan = opt.use_gfpgan
    sr_weight = opt.sr_weight
    use_color_trans = opt.use_color_trans
    color_trans_mode = opt.color_trans_mode
    print("use_gfpgan:", use_gfpgan, "use use_color_trans:", use_color_trans)

    src_img_path = opt.src_img_path
    image_name = src_img_path.split('/')[-1].split('.')[0]
    video_path = opt.video_path
    print(video_path)
    video_name = video_path.split('/')[-1].split('.')[0]
    output_dir = opt.output_dir
    output_dir = os.path.join(output_dir, video_name)
    os.makedirs(output_dir, exist_ok=True)
    video = cv2.VideoCapture(video_path)
    video_fps = video.get(cv2.CAP_PROP_FPS)
    video_size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    video_frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

    print("ori_video_size:", video_size)
    if video_size != (1080, 1920) and opt.video_to_1080p:
        resize_video_path = video_path.replace('.mp4', '_1080p.mp4')
        if not os.path.exists(resize_video_path):
            change_video_resolution(video_path, resize_video_path)
        video_path = resize_video_path
        # reopen the capture on the resized video and refresh its properties
        video.release()
        video = cv2.VideoCapture(video_path)
        video_fps = video.get(cv2.CAP_PROP_FPS)
        video_size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        video_frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

    t1 = time.time()
    gi = GenInput(image_feature_src_list, frame_queue_in, frame_queue_out, video, src_img_path)

    go = GetOutput(frame_queue_out, video_path, model_name, output_dir, video_fps, video_size, video_frame_count, image_name,
                   align_method, use_gfpgan, sr_weight, use_color_trans, color_trans_mode)

    fs = FaceSwap(image_feature_src_list, frame_queue_in, frame_queue_out,
                  model_name=model_name, align_method=align_method, use_gfpgan=use_gfpgan,
                  sr_weight=sr_weight, use_color_trans=use_color_trans, color_trans_mode=color_trans_mode)

    gi.start()
    go.start()
    fs.start()

    gi.join()
    print('gi stop')
    go.join()
    print('go stop')
    fs.join()
    print('fs stop')

    video.release()

    print("time cost:", time.time()-t1)