Uniaff committed on
Commit 6f69912 · verified · 1 Parent(s): d627392

Update inference.py

Files changed (1)
  1. inference.py +354 -333
inference.py CHANGED
@@ -1,362 +1,383 @@
1
- from os import listdir, path
2
  import numpy as np
3
- import scipy, cv2, os, sys, argparse, audio
4
- import json, subprocess, random, string
5
  from tqdm import tqdm
6
- from glob import glob
7
- import torch, face_detection
8
  from wav2lip_models import Wav2Lip
9
- import platform
10
  from face_parsing import init_parser, swap_regions
11
- from esrgan.upsample import upscale
12
- from esrgan.upsample import load_sr
13
- from basicsr.archs.rrdbnet_arch import RRDBNet
14
  from basicsr.utils.download_util import load_file_from_url
15
 
16
- parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
17
-
18
- parser.add_argument('--checkpoint_path', type=str, default="checkpoints/wav2lip_gan.pth",
19
- help='Name of saved checkpoint to load weights from', required=False)
20
-
21
- parser.add_argument('--segmentation_path', type=str, default="checkpoints/face_segmentation.pth",
22
- help='Name of saved checkpoint of segmentation network', required=False)
23
-
24
- parser.add_argument('--sr_path', type=str, default='weights/4x_BigFace_v3_Clear.pth',
25
- help='Name of saved checkpoint of super-resolution network', required=False)
26
-
27
- parser.add_argument('--face', type=str,
28
- help='Filepath of video/image that contains faces to use', required=True)
29
- parser.add_argument('--audio', type=str,
30
- help='Filepath of video/audio file to use as raw audio source', required=True)
31
- parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
32
- default='results/result_voice.mp4')
33
-
34
-
35
- parser.add_argument('--static', type=bool,
36
- help='If True, then use only first video frame for inference', default=False)
37
- parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
38
- default=25., required=False)
39
-
40
- parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
41
- help='Padding (top, bottom, left, right). Please adjust to include chin at least')
42
-
43
- parser.add_argument('--face_det_batch_size', type=int,
44
- help='Batch size for face detection', default=16)
45
- parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
46
-
47
- parser.add_argument('--resize_factor', default=1, type=int,
48
- help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
49
-
50
- parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
51
- help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
52
- 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
53
-
54
- parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
55
- help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
56
- 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
57
-
58
- parser.add_argument('--rotate', default=False, action='store_true',
59
- help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
60
- 'Use if you get a flipped result, despite feeding a normal looking video')
61
-
62
- parser.add_argument('--nosmooth', default=False, action='store_true',
63
- help='Prevent smoothing face detections over a short temporal window')
64
- parser.add_argument('--no_seg', default=False, action='store_true',
65
- help='Prevent using face segmentation')
66
- parser.add_argument('--no_sr', default=False, action='store_true',
67
- help='Prevent using super resolution')
68
- parser.add_argument('--enhance_face', default=None, choices=['gfpgan','codeformer'],
69
- help='Use GFP-GAN to enhance facial details.')
70
- parser.add_argument('-w', '--fidelity_weight', type=float, default=0.75,
71
- help='Balance the quality and fidelity. Default: 0.75')
72
- parser.add_argument('--save_frames', default=False, action='store_true',
73
- help='Save each frame as an image. Use with caution')
74
- parser.add_argument('--gt_path', type=str,
75
- help='Where to store saved ground truth frames', required=False)
76
- parser.add_argument('--pred_path', type=str,
77
- help='Where to store frames produced by algorithm', required=False)
78
- parser.add_argument('--save_as_video', action="store_true", default=False,
79
- help='Whether to save frames as video', required=False)
80
- parser.add_argument('--image_prefix', type=str, default="",
81
- help='Prefix to save frames with', required=False)
82
-
83
- args = parser.parse_args()
84
- args.img_size = 96
85
-
86
- if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
87
- args.static = True
88
 
89
  def get_smoothened_boxes(boxes, T):
90
- for i in range(len(boxes)):
91
- if i + T > len(boxes):
92
- window = boxes[len(boxes) - T:]
93
- else:
94
- window = boxes[i : i + T]
95
- boxes[i] = np.mean(window, axis=0)
96
- return boxes
97
-
98
- def face_detect(images):
99
- detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
100
- flip_input=False, device=device)
101
-
102
- batch_size = args.face_det_batch_size
103
-
104
- while 1:
105
- predictions = []
106
- try:
107
- for i in range(0, len(images), batch_size):
108
- predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
109
- except RuntimeError:
110
- if batch_size == 1:
111
- raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
112
- batch_size //= 2
113
- print('Recovering from OOM error; New batch size: {}'.format(batch_size))
114
- continue
115
- break
116
-
117
- results = []
118
- pady1, pady2, padx1, padx2 = args.pads
119
- for rect, image in zip(predictions, images):
120
- if rect is None:
121
- continue
122
- y1 = max(0, rect[1] - pady1)
123
- y2 = min(image.shape[0], rect[3] + pady2)
124
- x1 = max(0, rect[0] - padx1)
125
- x2 = min(image.shape[1], rect[2] + padx2)
126
-
127
- results.append([x1, y1, x2, y2])
128
-
129
- boxes = np.array(results)
130
- if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
131
- results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
132
-
133
- del detector
134
- return results
135
-
136
- def datagen(mels):
137
  img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
138
 
139
- # Uncommented code removed for clarity
140
-
141
- reader = read_frames()
142
-
143
- for i, m in enumerate(mels):
144
- try:
145
- frame_to_save = next(reader)
146
- except StopIteration:
147
- reader = read_frames()
148
  frame_to_save = next(reader, None)
149
-
150
- if frame_to_save is not None:
151
- face_detect_result = face_detect([frame_to_save])
152
- if len(face_detect_result) > 0: # Check if face detection was successful
153
- face, coords = face_detect_result[0]
154
- face = cv2.resize(face, (args.img_size, args.img_size))
155
- img_batch.append(face)
156
- mel_batch.append(m)
157
- frame_batch.append(frame_to_save)
158
- coords_batch.append(coords)
 
159
 
160
  if len(img_batch) >= args.wav2lip_batch_size:
161
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
 
162
 
163
- img_masked = img_batch.copy()
164
  img_masked[:, args.img_size // 2:] = 0
165
 
166
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
167
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
168
 
169
- yield img_batch, mel_batch, frame_batch, coords_batch
170
  img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
171
 
172
  if len(img_batch) > 0:
173
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
 
174
 
175
- img_masked = img_batch.copy()
176
  img_masked[:, args.img_size // 2:] = 0
177
 
178
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
179
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
180
-
181
- yield img_batch, mel_batch, frame_batch, coords_batch
182
-
183
- mel_step_size = 16
184
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
185
- print('Using {} for inference.'.format(device))
186
-
187
- def _load(checkpoint_path):
188
- if device == 'cuda':
189
- checkpoint = torch.load(checkpoint_path)
190
- else:
191
- checkpoint = torch.load(checkpoint_path,
192
- map_location=lambda storage, loc: storage)
193
- return checkpoint
194
-
195
- def load_model(path):
196
- model = Wav2Lip()
197
- print("Load checkpoint from: {}".format(path))
198
- checkpoint = _load(path)
199
- s = checkpoint["state_dict"]
200
- new_s = {}
201
- for k, v in s.items():
202
- new_s[k.replace('module.', '')] = v
203
- model.load_state_dict(new_s)
204
-
205
- model = model.to(device)
206
- return model.eval()
207
-
208
- def read_frames():
209
- if args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
210
- face = cv2.imread(args.face)
211
- while 1:
212
- yield face
213
-
214
- video_stream = cv2.VideoCapture(args.face)
215
- fps = video_stream.get(cv2.CAP_PROP_FPS)
216
-
217
- print('Reading video frames from start...')
218
-
219
- while 1:
220
- still_reading, frame = video_stream.read()
221
- if not still_reading:
222
- video_stream.release()
223
- break
224
- if args.resize_factor > 1:
225
- frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
226
-
227
- if args.rotate:
228
- frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
229
-
230
- y1, y2, x1, x2 = args.crop
231
- if x2 == -1: x2 = frame.shape[1]
232
- if y2 == -1: y2 = frame.shape[0]
233
-
234
- frame = frame[y1:y2, x1:x2]
 
235
 
236
- yield frame
237
 
238
  def main():
239
- if not os.path.isfile(args.face):
240
- raise ValueError('--face argument must be a valid path to video/image file')
241
-
242
- elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
243
- fps = args.fps
244
- else:
245
- video_stream = cv2.VideoCapture(args.face)
246
- fps = video_stream.get(cv2.CAP_PROP_FPS)
247
- video_stream.release()
248
-
249
-
250
- if not args.audio.endswith('.wav'):
251
- print('Extracting raw audio...')
252
- command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
253
-
254
- subprocess.call(command, shell=True)
255
- args.audio = 'temp/temp.wav'
256
-
257
- wav = audio.load_wav(args.audio, 16000)
258
- mel = audio.melspectrogram(wav)
259
- print(mel.shape)
260
-
261
- if np.isnan(mel.reshape(-1)).sum() > 0:
262
- raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
263
-
264
- mel_chunks = []
265
- mel_idx_multiplier = 80./fps
266
- i = 0
267
- while 1:
268
- start_idx = int(i * mel_idx_multiplier)
269
- if start_idx + mel_step_size > len(mel[0]):
270
- mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
271
- break
272
- mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
273
- i += 1
274
-
275
- print("Length of mel chunks: {}".format(len(mel_chunks)))
276
-
277
- batch_size = args.wav2lip_batch_size
278
- gen = datagen(mel_chunks)
279
-
280
-
281
-
282
- if args.save_as_video:
283
- gt_out = cv2.VideoWriter("temp/gt.avi", cv2.VideoWriter_fourcc(*'DIVX'), fps, (384, 384))
284
- pred_out = cv2.VideoWriter("temp/pred.avi", cv2.VideoWriter_fourcc(*'DIVX'), fps, (96, 96))
285
-
286
- abs_idx = 0
287
- for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
288
- total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
289
- if i == 0:
290
- if not args.no_seg==True:
291
- print("Loading segmentation network...")
292
- seg_net = load_file_from_url(
293
- url='https://github.com/GucciFlipFlops1917/wav2lip-hq-updated-ESRGAN/releases/download/v0.0.1/face_segmentation.pth',
294
- model_dir='checkpoints', progress=True, file_name=None)
295
- seg_net = init_parser(args.segmentation_path)
296
- if not args.no_sr==True:
297
- print("Loading super resolution model...")
298
- run_params = load_sr(args.sr_path, device, args.enhance_face)
299
-
300
- model_path = load_file_from_url(
301
- url='https://github.com/GucciFlipFlops1917/wav2lip-hq-updated-ESRGAN/releases/download/v0.0.1/wav2lip_gan.pth',
302
- model_dir='checkpoints', progress=True, file_name=None)
303
- model = load_model(args.checkpoint_path)
304
- print ("Model loaded")
305
-
306
- frame_h, frame_w = next(read_frames()).shape[:-1]
307
- out = cv2.VideoWriter('temp/result.avi',
308
- cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
309
-
310
- img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
311
- mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
312
-
313
- with torch.no_grad():
314
- pred = model(mel_batch, img_batch)
315
-
316
- pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
317
-
318
- for p, f, c in zip(pred, frames, coords):
319
- y1, y2, x1, x2 = c
320
-
321
- if args.save_frames:
322
- if args.save_as_video:
323
- pred_out.write(p.astype(np.uint8))
324
- gt_out.write(cv2.resize(f[y1:y2, x1:x2], (384, 384)))
325
- else:
326
- cv2.imwrite(f"{args.gt_path}/{args.image_prefix}{abs_idx}.png", f[y1:y2, x1:x2])
327
- cv2.imwrite(f"{args.pred_path}/{args.image_prefix}{abs_idx}.png", p)
328
- abs_idx += 1
329
-
330
- if not args.no_sr:
331
- if args.enhance_face==None:
332
- p = upscale(p, 0, run_params)
333
- elif args.enhance_face=='codeformer':
334
- p = upscale(p, 2, [run_params, device, args.fidelity_weight])
335
- elif args.enhance_face=='gfpgan':
336
- p = upscale(p, 1, run_params)
337
- p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
338
-
339
- if args.no_seg==False:
340
- p = swap_regions(f[y1:y2, x1:x2], p, seg_net)
341
-
342
- f[y1:y2, x1:x2] = p
343
- out.write(f)
344
-
345
- out.release()
346
-
347
- command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
348
- subprocess.call(command, shell=platform.system() != 'Windows')
349
-
350
- if args.save_frames and args.save_as_video:
351
- gt_out.release()
352
- pred_out.release()
353
-
354
- command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/gt.avi', args.gt_path)
355
- subprocess.call(command, shell=platform.system() != 'Windows')
356
-
357
- command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/pred.avi', args.pred_path)
358
- subprocess.call(command, shell=platform.system() != 'Windows')
359
 
360
 
361
  if __name__ == '__main__':
362
- main()
 
1
+ import os
2
+ import argparse
3
+ import subprocess
4
+ import platform
5
+
6
  import numpy as np
7
+ import cv2
8
+ import torch
9
  from tqdm import tqdm
10
+
11
+ from face_detection import FaceAlignment, LandmarksType
12
  from wav2lip_models import Wav2Lip
 
13
  from face_parsing import init_parser, swap_regions
14
+ from esrgan.upsample import upscale, load_sr
 
 
15
  from basicsr.utils.download_util import load_file_from_url
16
 
17
+ import audio # Assumed to be your local module for audio processing
18
+
19
+ # Optimized imports: removed unused libraries (scipy, json, random, string, glob)
20
+
21
+
22
+ def parse_arguments():
23
+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
24
+
25
+ parser.add_argument('--checkpoint_path', type=str, default="checkpoints/wav2lip_gan.pth",
26
+ help='Name of saved checkpoint to load weights from', required=False)
27
+
28
+ parser.add_argument('--segmentation_path', type=str, default="checkpoints/face_segmentation.pth",
29
+ help='Name of saved checkpoint of segmentation network', required=False)
30
+
31
+ parser.add_argument('--sr_path', type=str, default='weights/4x_BigFace_v3_Clear.pth',
32
+ help='Name of saved checkpoint of super-resolution network', required=False)
33
+
34
+ parser.add_argument('--face', type=str,
35
+ help='Filepath of video/image that contains faces to use', required=True)
36
+ parser.add_argument('--audio', type=str,
37
+ help='Filepath of video/audio file to use as raw audio source', required=True)
38
+ parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
39
+ default='results/result_voice.mp4')
40
+
41
+ parser.add_argument('--static', action='store_true',
42
+ help='If set, use only first video frame for inference')
43
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
44
+ default=25., required=False)
45
+
46
+ parser.add_argument('--pads', nargs=4, type=int, default=[0, 10, 0, 0],
47
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
48
+
49
+ parser.add_argument('--face_det_batch_size', type=int,
50
+ help='Batch size for face detection', default=16)
51
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
52
+
53
+ parser.add_argument('--resize_factor', default=1, type=int,
54
+ help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
55
+
56
+ parser.add_argument('--crop', nargs=4, type=int, default=[0, -1, 0, -1],
57
+ help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
58
+ 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
59
+
60
+ parser.add_argument('--box', nargs=4, type=int, default=[-1, -1, -1, -1],
61
+ help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
62
+ 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
63
+
64
+ parser.add_argument('--rotate', action='store_true',
65
+ help='Sometimes videos taken from a phone can be flipped 90deg. If set, will flip video right by 90deg.'
66
+ 'Use if you get a flipped result, despite feeding a normal looking video')
67
+
68
+ parser.add_argument('--nosmooth', action='store_true',
69
+ help='Prevent smoothing face detections over a short temporal window')
70
+ parser.add_argument('--no_seg', action='store_true',
71
+ help='Prevent using face segmentation')
72
+ parser.add_argument('--no_sr', action='store_true',
73
+ help='Prevent using super resolution')
74
+ parser.add_argument('--enhance_face', choices=['gfpgan','codeformer'],
75
+ help='Use GFP-GAN or CodeFormer to enhance facial details.')
76
+ parser.add_argument('-w', '--fidelity_weight', type=float, default=0.75,
77
+ help='Balance the quality and fidelity. Default: 0.75')
78
+ parser.add_argument('--save_frames', action='store_true',
79
+ help='Save each frame as an image. Use with caution')
80
+ parser.add_argument('--gt_path', type=str,
81
+ help='Where to store saved ground truth frames', required=False)
82
+ parser.add_argument('--pred_path', type=str,
83
+ help='Where to store frames produced by algorithm', required=False)
84
+ parser.add_argument('--save_as_video', action="store_true", default=False,
85
+ help='Whether to save frames as video', required=False)
86
+ parser.add_argument('--image_prefix', type=str, default="",
87
+ help='Prefix to save frames with', required=False)
88
+
89
+ args = parser.parse_args()
90
+
91
+ # Determine whether the input is a static image
92
+ if os.path.isfile(args.face) and os.path.splitext(args.face)[1].lower() in ['.jpg', '.png', '.jpeg']:
93
+ args.static = True
94
+
95
+ args.img_size = 96
96
+ return args
97
+
98
 
99
  def get_smoothened_boxes(boxes, T):
100
+ for i in range(len(boxes)):
101
+ window = boxes[max(i - T + 1, 0):i + 1]
102
+ boxes[i] = np.mean(window, axis=0)
103
+ return boxes
104
+
105
+
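A minimal, self-contained sketch of the trailing-window smoothing above, run on a hypothetical sequence of detection boxes (the function is restated so the snippet runs standalone; T=5 matches the value face_detect uses below):

import numpy as np

# Five hypothetical (x1, y1, x2, y2) boxes with one jittery outlier in the middle.
boxes = np.array([[10, 10, 50, 50],
                  [11, 10, 51, 50],
                  [30, 25, 70, 65],   # outlier
                  [12, 11, 52, 51],
                  [11, 10, 51, 50]], dtype=float)

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        window = boxes[max(i - T + 1, 0):i + 1]   # trailing window of up to T boxes
        boxes[i] = np.mean(window, axis=0)
    return boxes

print(get_smoothened_boxes(boxes.copy(), T=5)[2])  # the outlier is pulled toward its neighbours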
106
+ def face_detect(detector, images, args):
107
+ predictions = []
108
+ batch_size = args.face_det_batch_size
109
+
110
+ try:
111
+ for i in range(0, len(images), batch_size):
112
+ batch_images = np.array(images[i:i + batch_size])
113
+ predictions.extend(detector.get_detections_for_batch(batch_images))
114
+ except RuntimeError:
115
+ if batch_size == 1:
116
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
117
+ batch_size //= 2
118
+ print(f'Recovering from OOM error; New batch size: {batch_size}')
119
+ args.face_det_batch_size = batch_size
+ return face_detect(detector, images, args) # Retry recursively with the reduced batch size
120
+
121
+ results = []
122
+ pady1, pady2, padx1, padx2 = args.pads
123
+ for rect, image in zip(predictions, images):
124
+ if rect is None:
125
+ continue
126
+ y1 = max(0, rect[1] - pady1)
127
+ y2 = min(image.shape[0], rect[3] + pady2)
128
+ x1 = max(0, rect[0] - padx1)
129
+ x2 = min(image.shape[1], rect[2] + padx2)
130
+
131
+ results.append([x1, y1, x2, y2])
132
+
133
+ boxes = np.array(results)
134
+ if not args.nosmooth and len(boxes) > 0:
135
+ boxes = get_smoothened_boxes(boxes, T=5)
136
+
137
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)]
138
+ for image, (x1, y1, x2, y2) in zip(images, boxes)]
139
+
140
+ return results
141
+
142
+
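A small illustration of the pad-and-clamp step inside face_detect, using a hypothetical detector rectangle, the default --pads of (0, 10, 0, 0), and a made-up frame size:

rect = (40, 60, 120, 150)                  # hypothetical (x1, y1, x2, y2) from the detector
pady1, pady2, padx1, padx2 = 0, 10, 0, 0   # the --pads default: top, bottom, left, right
frame_h, frame_w = 155, 160                # hypothetical frame size

y1 = max(0, rect[1] - pady1)
y2 = min(frame_h, rect[3] + pady2)         # bottom padding clamped to the frame height
x1 = max(0, rect[0] - padx1)
x2 = min(frame_w, rect[2] + padx2)
print(x1, y1, x2, y2)                      # 40 60 120 155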
143
+ def datagen(mels, reader, detector, args):
 
 
 
144
  img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
145
 
146
+ for m in mels:
147
+ frame_to_save = next(reader, None)
148
+ if frame_to_save is None:
149
+ reader = read_frames(args.face, args.resize_factor, args.rotate, args.crop)
150
  frame_to_save = next(reader, None)
151
+ if frame_to_save is None:
152
+ break
153
+
154
+ face_detect_result = face_detect(detector, [frame_to_save], args)
155
+ if len(face_detect_result) > 0: # Check if face detection was successful
156
+ face, coords = face_detect_result[0]
157
+ face = cv2.resize(face, (args.img_size, args.img_size))
158
+ img_batch.append(face)
159
+ mel_batch.append(m)
160
+ frame_batch.append(frame_to_save)
161
+ coords_batch.append(coords)
162
 
163
  if len(img_batch) >= args.wav2lip_batch_size:
164
+ img_batch_np = np.asarray(img_batch)
165
+ mel_batch_np = np.asarray(mel_batch)
166
 
167
+ img_masked = img_batch_np.copy()
168
  img_masked[:, args.img_size // 2:] = 0
169
 
170
+ img_batch_np = np.concatenate((img_masked, img_batch_np), axis=3) / 255.0
171
+ mel_batch_np = mel_batch_np.reshape(len(mel_batch_np), mel_batch_np.shape[1], mel_batch_np.shape[2], 1)
172
 
173
+ yield img_batch_np, mel_batch_np, frame_batch, coords_batch
174
  img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
175
 
176
  if len(img_batch) > 0:
177
+ img_batch_np = np.asarray(img_batch)
178
+ mel_batch_np = np.asarray(mel_batch)
179
 
180
+ img_masked = img_batch_np.copy()
181
  img_masked[:, args.img_size // 2:] = 0
182
 
183
+ img_batch_np = np.concatenate((img_masked, img_batch_np), axis=3) / 255.0
184
+ mel_batch_np = mel_batch_np.reshape(len(mel_batch_np), mel_batch_np.shape[1], mel_batch_np.shape[2], 1)
185
+
186
+ yield img_batch_np, mel_batch_np, frame_batch, coords_batch
187
+
188
+
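A quick shape check of what datagen yields, with an arbitrary batch size and assuming the 80 mel bins Wav2Lip is normally trained with: the masked copy is concatenated with the reference face along the channel axis, so the model receives 6-channel 96x96 crops.

import numpy as np

img_size, batch = 96, 4                      # img_size matches args.img_size; batch is arbitrary
img_batch = np.zeros((batch, img_size, img_size, 3))
img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0            # lower half (mouth region) zeroed out

model_input = np.concatenate((img_masked, img_batch), axis=3) / 255.0
mel_batch = np.zeros((batch, 80, 16)).reshape(batch, 80, 16, 1)   # 80 mel bins, 16-step chunks

print(model_input.shape, mel_batch.shape)    # (4, 96, 96, 6) (4, 80, 16, 1)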
189
+ def load_checkpoint(checkpoint_path, device):
190
+ if device == 'cuda':
191
+ checkpoint = torch.load(checkpoint_path)
192
+ else:
193
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
194
+ return checkpoint
195
+
196
+
197
+ def load_model(checkpoint_path, device):
198
+ model = Wav2Lip()
199
+ print(f"Loading checkpoint from: {checkpoint_path}")
200
+ checkpoint = load_checkpoint(checkpoint_path, device)
201
+ state_dict = checkpoint["state_dict"]
202
+ new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
203
+ model.load_state_dict(new_state_dict)
204
+ model = model.to(device)
205
+ model.eval()
206
+ return model
207
+
208
+
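A toy illustration of the key renaming in load_model: checkpoints saved from an nn.DataParallel wrapper prefix every parameter with "module.", which a bare Wav2Lip() instance does not expect (the keys below are made up):

state_dict = {"module.face_encoder.0.weight": 1, "module.audio_encoder.0.weight": 2}
print({k.replace("module.", ""): v for k, v in state_dict.items()})
# {'face_encoder.0.weight': 1, 'audio_encoder.0.weight': 2}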
209
+ def read_frames(face_path, resize_factor, rotate, crop):
210
+ if os.path.splitext(face_path)[1].lower() in ['.jpg', '.png', '.jpeg']:
211
+ face = cv2.imread(face_path)
212
+ if resize_factor > 1:
213
+ face = cv2.resize(face, (face.shape[1]//resize_factor, face.shape[0]//resize_factor))
214
+ if rotate:
215
+ face = cv2.rotate(face, cv2.ROTATE_90_CLOCKWISE)
216
+ y1, y2, x1, x2 = crop
217
+ if x2 == -1: x2 = face.shape[1]
218
+ if y2 == -1: y2 = face.shape[0]
219
+ face = face[y1:y2, x1:x2]
220
+ while True:
221
+ yield face
222
+ else:
223
+ video_stream = cv2.VideoCapture(face_path)
224
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
225
+ print('Reading video frames from start...')
226
+
227
+ while True:
228
+ still_reading, frame = video_stream.read()
229
+ if not still_reading:
230
+ video_stream.release()
231
+ break
232
+ if resize_factor > 1:
233
+ frame = cv2.resize(frame, (frame.shape[1]//resize_factor, frame.shape[0]//resize_factor))
234
+ if rotate:
235
+ frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
236
+ y1, y2, x1, x2 = crop
237
+ if x2 == -1: x2 = frame.shape[1]
238
+ if y2 == -1: y2 = frame.shape[0]
239
+ frame = frame[y1:y2, x1:x2]
240
+ yield frame
241
 
 
242
 
243
  def main():
244
+ args = parse_arguments()
245
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
246
+ print(f'Using {device} for inference.')
247
+
248
+ # Initialize models outside the loops
249
+ detector = FaceAlignment(LandmarksType._2D, flip_input=False, device=device)
250
+
251
+ if not args.no_seg:
252
+ print("Loading segmentation network...")
253
+ seg_net = init_parser(args.segmentation_path)
254
+ else:
255
+ seg_net = None
256
+
257
+ if not args.no_sr:
258
+ print("Loading super resolution model...")
259
+ run_params = load_sr(args.sr_path, device, args.enhance_face)
260
+ else:
261
+ run_params = None
262
+
263
+ model = load_model(args.checkpoint_path, device)
264
+ print("Model loaded")
265
+
266
+ if not os.path.isfile(args.face):
267
+ raise ValueError('--face argument must be a valid path to video/image file')
268
+
269
+ if not args.audio.endswith('.wav'):
270
+ print('Extracting raw audio...')
271
+ temp_wav = 'temp/temp.wav'
272
+ command = f'ffmpeg -y -i "{args.audio}" -strict -2 "{temp_wav}"'
273
+ subprocess.call(command, shell=True)
274
+ args.audio = temp_wav
275
+
276
+ wav = audio.load_wav(args.audio, 16000)
277
+ mel = audio.melspectrogram(wav)
278
+ print(mel.shape)
279
+
280
+ if np.isnan(mel).any():
281
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
282
+
283
+ mel_step_size = 16
284
+ fps = args.fps if args.static else None
285
+
286
+ if not args.static:
287
+ video_stream = cv2.VideoCapture(args.face)
288
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
289
+ video_stream.release()
290
+
291
+ mel_idx_multiplier = 80.0 / fps
292
+ mel_chunks = []
293
+ i = 0
294
+ while True:
295
+ start_idx = int(i * mel_idx_multiplier)
296
+ if start_idx + mel_step_size > mel.shape[1]:
297
+ mel_chunks.append(mel[:, -mel_step_size:])
298
+ break
299
+ mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
300
+ i += 1
301
+
302
+ print(f"Length of mel chunks: {len(mel_chunks)}")
303
+
304
+ reader = read_frames(args.face, args.resize_factor, args.rotate, args.crop)
305
+ generator = datagen(mel_chunks, reader, detector, args)
306
+
307
+ if args.save_as_video:
308
+ frame_sample = next(reader)
309
+ frame_h, frame_w = frame_sample.shape[:2]
310
+ out = cv2.VideoWriter('temp/result.avi',
311
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
312
+ if args.save_frames:
313
+ gt_out = cv2.VideoWriter("temp/gt.avi", cv2.VideoWriter_fourcc(*'DIVX'), fps, (384, 384))
314
+ pred_out = cv2.VideoWriter("temp/pred.avi", cv2.VideoWriter_fourcc(*'DIVX'), fps, (96, 96))
315
+ else:
316
+ out = None
317
+ gt_out = None
318
+ pred_out = None
319
+
320
+ abs_idx = 0
321
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(generator,
322
+ total=int(np.ceil(len(mel_chunks)/args.wav2lip_batch_size)))):
323
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
324
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
325
+
326
+ with torch.no_grad():
327
+ pred = model(mel_batch, img_batch)
328
+
329
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0
330
+
331
+ for p, f, c in zip(pred, frames, coords):
332
+ y1, y2, x1, x2 = c
333
+
334
+ if args.save_frames:
335
+ if args.save_as_video:
336
+ pred_out.write(p.astype(np.uint8))
337
+ gt_resized = cv2.resize(f[y1:y2, x1:x2], (384, 384))
338
+ gt_out.write(gt_resized)
339
+ else:
340
+ if args.gt_path and args.pred_path:
341
+ os.makedirs(args.gt_path, exist_ok=True)
342
+ os.makedirs(args.pred_path, exist_ok=True)
343
+ cv2.imwrite(f"{args.gt_path}/{args.image_prefix}{abs_idx}.png", f[y1:y2, x1:x2])
344
+ cv2.imwrite(f"{args.pred_path}/{args.image_prefix}{abs_idx}.png", p)
345
+ abs_idx += 1
346
+
347
+ if not args.no_sr:
348
+ if args.enhance_face is None:
349
+ p = upscale(p, 0, run_params)
350
+ elif args.enhance_face == 'codeformer':
351
+ p = upscale(p, 2, [run_params, device, args.fidelity_weight])
352
+ elif args.enhance_face == 'gfpgan':
353
+ p = upscale(p, 1, run_params)
354
+
355
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
356
+
357
+ if not args.no_seg and seg_net is not None:
358
+ p = swap_regions(f[y1:y2, x1:x2], p, seg_net)
359
+
360
+ f[y1:y2, x1:x2] = p
361
+ if out:
362
+ out.write(f)
363
+
364
+ if out:
365
+ out.release()
366
+
367
+ # Combine audio and video
368
+ final_command = f'ffmpeg -y -i "{args.audio}" -i "temp/result.avi" -strict -2 -q:v 1 "{args.outfile}"'
369
+ subprocess.call(final_command, shell=(platform.system() != 'Windows'))
370
+
371
+ if args.save_frames and args.save_as_video:
372
+ gt_out.release()
373
+ pred_out.release()
374
+
375
+ gt_video_cmd = f'ffmpeg -y -i "temp/gt.avi" -i "{args.audio}" -strict -2 -q:v 1 "{args.gt_path}"'
376
+ pred_video_cmd = f'ffmpeg -y -i "temp/pred.avi" -i "{args.audio}" -strict -2 -q:v 1 "{args.pred_path}"'
377
+
378
+ subprocess.call(gt_video_cmd, shell=(platform.system() != 'Windows'))
379
+ subprocess.call(pred_video_cmd, shell=(platform.system() != 'Windows'))
380
 
381
 
382
  if __name__ == '__main__':
383
+ main()
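For reference, a minimal, hypothetical invocation of the updated script (the file paths are placeholders; the flags are the ones defined in parse_arguments above):

import subprocess

subprocess.run([
    "python", "inference.py",
    "--face", "input/speaker.mp4",            # video (or image) containing the target face
    "--audio", "input/speech.wav",            # driving audio track
    "--outfile", "results/result_voice.mp4",  # where the muxed result is written
    "--resize_factor", "2",                   # optional: process at half resolution
], check=True)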