ibrahim313 committed on
Commit 24bde82 · verified
1 Parent(s): 885d086

Upload 12 files

Files changed (12)
  1. GUI.py +361 -0
  2. README.md +1 -12
  3. audio.py +179 -0
  4. config.ini +51 -0
  5. degradations.py +764 -0
  6. easy_functions.py +196 -0
  7. enhance.py +22 -0
  8. hparams.py +99 -0
  9. inference.py +781 -0
  10. install.py +96 -0
  11. requirements.txt +18 -0
  12. run.py +496 -0
GUI.py ADDED
@@ -0,0 +1,361 @@
+ import tkinter as tk
+ from tkinter import filedialog, ttk
+ import configparser
+ import os
+
+ try:
+     with open('installed.txt', 'r') as file:
+         version = file.read()
+ except FileNotFoundError:
+     print("SyncKing-Kong does not appear to have installed correctly.")
+     print("Please try to install it again.")
+     print("https://github.com/anothermartz/Easy-Wav2Lip/issues")
+     input()
+     exit()
+
+ print("opening GUI")
+
+ runfile = 'run.txt'
+ if os.path.exists(runfile):
+     os.remove(runfile)
+
+ import webbrowser
+
+ def open_github_link(event):
+     webbrowser.open("https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#advanced-tweaking")
+
+ def read_config():
+     # Read the config.ini file
+     config = configparser.ConfigParser()
+     config.read("config.ini")
+     return config
+
+ def save_config(config):
+     # Save the updated config back to config.ini
+     with open("config.ini", "w") as config_file:
+         config.write(config_file)
+
+ def open_video_file():
+     file_path = filedialog.askopenfilename(title="Select a video file", filetypes=[("All files", "*.*")])
+     if file_path:
+         video_file_var.set(file_path)
+
+ def open_vocal_file():
+     file_path = filedialog.askopenfilename(title="Select a vocal file", filetypes=[("All files", "*.*")])
+     if file_path:
+         vocal_file_var.set(file_path)
+
+ # frame-to-preview entry: allow only whole numbers (or empty input)
+ def validate_frame_preview(P):
+     if P == "":
+         return True  # Allow empty input
+     try:
+         num = float(P)
+         if num.is_integer():
+             return True
+     except ValueError:
+         pass
+     return False
+
+ def start_syncking_kong():
+     print("Saving config")
+     config["OPTIONS"]["video_file"] = str(video_file_var.get())
+     config["OPTIONS"]["vocal_file"] = str(vocal_file_var.get())
+     config["OPTIONS"]["quality"] = str(quality_var.get())
+     config["OPTIONS"]["output_height"] = str(output_height_combobox.get())
+     config["OPTIONS"]["wav2lip_version"] = str(wav2lip_version_var.get())
+     config["OPTIONS"]["use_previous_tracking_data"] = str(use_previous_tracking_data_var.get())
+     config["OPTIONS"]["nosmooth"] = str(nosmooth_var.get())
+     config["OPTIONS"]["preview_window"] = str(preview_window_var.get())
+     config["PADDING"]["u"] = str(padding_vars["u"].get())
+     config["PADDING"]["d"] = str(padding_vars["d"].get())
+     config["PADDING"]["l"] = str(padding_vars["l"].get())
+     config["PADDING"]["r"] = str(padding_vars["r"].get())
+     config["MASK"]["size"] = str(size_var.get())
+     config["MASK"]["feathering"] = str(feathering_var.get())
+     config["MASK"]["mouth_tracking"] = str(mouth_tracking_var.get())
+     config["MASK"]["debug_mask"] = str(debug_mask_var.get())
+     config["OTHER"]["batch_process"] = str(batch_process_var.get())
+     config["OTHER"]["output_suffix"] = str(output_suffix_var.get())
+     config["OTHER"]["include_settings_in_suffix"] = str(include_settings_in_suffix_var.get())
+     config["OTHER"]["preview_settings"] = str(preview_settings_var.get())
+     config["OTHER"]["frame_to_preview"] = str(frame_to_preview_var.get())
+     save_config(config)  # Save the updated config
+     # Signal the runner by creating run.txt, then close the GUI
+     with open("run.txt", "w") as f:
+         f.write("run")
+     exit()
+
+ root = tk.Tk()
+ root.title("SyncKing-Kong GUI")
+ root.geometry("800x720")
+ root.configure(bg="lightblue")
+
+ # Read the existing config.ini
+ config = read_config()
+
+ row = 0
+ tk.Label(root, text=version, bg="lightblue").grid(row=row, column=0, sticky="w")
+ # Create a label for the video file
+ row += 1
+ video_label = tk.Label(root, text="Video File Path:", bg="lightblue")
+ video_label.grid(row=row, column=0, sticky="e")
+
+ # Entry widget for the video file path
+ video_file_var = tk.StringVar()
+ video_entry = tk.Entry(root, textvariable=video_file_var, width=80)
+ video_entry.grid(row=row, column=1, sticky="w")
+
+ # Create a button to open the file dialog
+ select_button = tk.Button(root, text="...", command=open_video_file)
+ select_button.grid(row=row, column=1, sticky="w", padx=490)
+
+ # Set the default value based on the existing config
+ video_file_var.set(config["OPTIONS"].get("video_file", ""))
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+
+ # String input for vocal_file
+ row += 1
+
+ # Create a label for the input box
+ vocal_file_label = tk.Label(root, text="Vocal File Path:", bg="lightblue")
+ vocal_file_label.grid(row=row, column=0, sticky="e")
+
+ # Create an input box for the vocal file path
+ vocal_file_var = tk.StringVar()
+ vocal_file_entry = tk.Entry(root, textvariable=vocal_file_var, width=80)
+ vocal_file_entry.grid(row=row, column=1, sticky="w")
+
+ # Create a button to open the file dialog
+ select_button = tk.Button(root, text="...", command=open_vocal_file)
+ select_button.grid(row=row, column=1, sticky="w", padx=490)
+
+ # Set the initial value from the config (if available)
+ vocal_file_var.set(config["OPTIONS"].get("vocal_file", ""))
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+
+ # Dropdown box for quality options
+ row += 1
+ quality_label = tk.Label(root, text="Select Quality:", bg="lightblue")
+ quality_label.grid(row=row, column=0, sticky="e")
+ quality_options = ["Fast", "Improved", "Enhanced"]
+ quality_var = tk.StringVar()
+ quality_var.set(config["OPTIONS"].get("quality", "Improved"))
+ quality_dropdown = tk.OptionMenu(root, quality_var, *quality_options)
+ quality_dropdown.grid(row=row, column=1, sticky="w")
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+
+ # Output height
+ row += 1
+ output_height_label = tk.Label(root, text="Output height:", bg="lightblue")
+ output_height_label.grid(row=row, column=0, sticky="e")
+ output_height_options = ["half resolution", "full resolution"]
+ output_height_combobox = ttk.Combobox(root, values=output_height_options)
+ output_height_combobox.set(config["OPTIONS"].get("output_height", "full resolution"))  # Set default value
+ output_height_combobox.grid(row=row, column=1, sticky="w")
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+
+ # Dropdown box for Wav2Lip version options
+ row += 1
+ wav2lip_version_label = tk.Label(root, text="Select Wav2Lip version:", bg="lightblue")
+ wav2lip_version_label.grid(row=row, column=0, sticky="e")
+ wav2lip_version_options = ["Wav2Lip", "Wav2Lip_GAN"]
+ wav2lip_version_var = tk.StringVar()
+ wav2lip_version_var.set(config["OPTIONS"].get("wav2lip_version", "Wav2Lip"))
+ wav2lip_version_dropdown = tk.OptionMenu(root, wav2lip_version_var, *wav2lip_version_options)
+ wav2lip_version_dropdown.grid(row=row, column=1, sticky="w")
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+ # output_suffix
+ row += 1
+ output_suffix_label = tk.Label(root, text="Output File Suffix:", bg="lightblue")
+ output_suffix_label.grid(row=row, column=0, sticky="e")
+ output_suffix_var = tk.StringVar()
+ output_suffix_var.set(config["OTHER"].get("output_suffix", "_SyncKing-Kong"))
+ output_suffix_entry = tk.Entry(root, textvariable=output_suffix_var, width=20)
+ output_suffix_entry.grid(row=row, column=1, sticky="w")
+
+ # configparser stores booleans as strings, so read them with getboolean() rather than get()
+ include_settings_in_suffix_var = tk.BooleanVar()
+ include_settings_in_suffix_var.set(config["OTHER"].getboolean("include_settings_in_suffix", fallback=True))
+ include_settings_in_suffix_checkbox = tk.Checkbutton(root, text="Add Settings to Suffix", variable=include_settings_in_suffix_var, bg="lightblue")
+ include_settings_in_suffix_checkbox.grid(row=row, column=1, sticky="w", padx=130)
+
+ # batch_process
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+ row += 1
+ batch_process_label = tk.Label(root, text="Batch Process:", bg="lightblue")
+ batch_process_label.grid(row=row, column=0, sticky="e")
+ batch_process_var = tk.BooleanVar()
+ batch_process_var.set(config["OTHER"].getboolean("batch_process", fallback=True))  # Set default value
+ batch_process_checkbox = tk.Checkbutton(root, text="", variable=batch_process_var, bg="lightblue")
+ batch_process_checkbox.grid(row=row, column=1, sticky="w")
+
+ # Dropdown box for preview window options
+ row += 1
+ preview_window_label = tk.Label(root, text="Preview Window:", bg="lightblue")
+ preview_window_label.grid(row=row, column=0, sticky="e")
+ preview_window_options = ["Face", "Full", "Both", "None"]
+ preview_window_var = tk.StringVar()
+ preview_window_var.set(config["OPTIONS"].get("preview_window", "Face"))
+ preview_window_dropdown = tk.OptionMenu(root, preview_window_var, *preview_window_options)
+ preview_window_dropdown.grid(row=row, column=1, sticky="w")
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+
+ # Button to start SyncKing-Kong
+ row += 1
+ start_button = tk.Button(root, text="Start SyncKing-Kong", command=start_syncking_kong, bg="#5af269", font=("Arial", 16))
+ start_button.grid(row=row, column=0, sticky="w", padx=290, columnspan=2)
+
+ row += 1
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+
+ row += 1
+ tk.Label(root, text="Advanced Tweaking:", bg="lightblue", font=("Arial", 16)).grid(row=row, column=0, sticky="w")
+ row += 1
+ # Create a label with a custom cursor
+ link = tk.Label(root, text="(Click here to see readme)", bg="lightblue", fg="blue", font=("Arial", 10), cursor="hand2")
+ link.grid(row=row, column=0)
+
+ # Bind the click event to the label
+ link.bind("<Button-1>", open_github_link)
+
+ # Process one frame only
+ preview_settings_var = tk.BooleanVar()
+ preview_settings_var.set(config["OTHER"].getboolean("preview_settings", fallback=True))  # Set default value
+ preview_settings_checkbox = tk.Checkbutton(root, text="Process one frame only - Frame to process:", variable=preview_settings_var, bg="lightblue")
+ preview_settings_checkbox.grid(row=row, column=1, sticky="w")
+
+ frame_to_preview_var = tk.StringVar()
+ frame_to_preview_entry = tk.Entry(root, textvariable=frame_to_preview_var, validate="key", width=3, validatecommand=(root.register(validate_frame_preview), "%P"))
+ frame_to_preview_entry.grid(row=row, column=1, sticky="w", padx=255)
+ frame_to_preview_var.set(config["OTHER"].get("frame_to_preview", "100"))
+
+ # Checkbox for nosmooth option
+ row += 1
+ nosmooth_var = tk.BooleanVar()
+ nosmooth_var.set(config["OPTIONS"].getboolean("nosmooth", fallback=True))  # Set default value
+ nosmooth_checkbox = tk.Checkbutton(root, text="nosmooth - unticking will smooth face detection between 5 frames", variable=nosmooth_var, bg="lightblue")
+ nosmooth_checkbox.grid(row=row, column=1, sticky="w")
+
+ # Checkbox for use_previous_tracking_data option
+ row += 1
+ use_previous_tracking_data_var = tk.BooleanVar()
+ use_previous_tracking_data_var.set(config["OPTIONS"].getboolean("use_previous_tracking_data", fallback=True))  # Set default value
+ use_previous_tracking_data_checkbox = tk.Checkbutton(root, text="Keep previous face tracking data if using same video", variable=use_previous_tracking_data_var, bg="lightblue")
+ use_previous_tracking_data_checkbox.grid(row=row, column=1, sticky="w")
+
+ # padding
+ row += 1
+ tk.Label(root, text="Padding:", bg="lightblue", font=("Arial", 12)).grid(row=row, column=1, sticky="sw", pady=10)
+ row += 1
+ tk.Label(root, text="(Up, Down, Left, Right)", bg="lightblue").grid(row=row, column=1, rowspan=4, sticky="w", padx=100)
+ padding_vars = {}
+
+ # Create a list of padding labels and their corresponding keys
+ padding_labels = [("U:", "u"), ("D:", "d"), ("L:", "l"), ("R:", "r")]
+
+ # Validation function to allow only integers (including negatives)
+ def validate_integer(P):
+     if P == "" or P == "-" or P.lstrip("-").isdigit():
+         return True
+     return False
+
+ # Create the padding labels and entry widgets using a loop
+ for label_text, key in padding_labels:
+     label = tk.Label(root, text=label_text, bg="lightblue")
+     label.grid(row=row, column=1, sticky="w", padx=50)
+
+     # Create a StringVar for the current key
+     padding_var = tk.StringVar()
+
+     # Set validation to allow positive and negative integers
+     entry = tk.Entry(root, textvariable=padding_var, width=3, validate="key", validatecommand=(root.register(validate_integer), "%P"))
+     entry.grid(row=row, column=1, sticky="w", padx=70)
+
+     # Set the default value from the config
+     padding_var.set(config["PADDING"].get(key, ""))
+
+     # Store the StringVar in the dictionary
+     padding_vars[key] = padding_var
+
+     # Increment the row
+     row += 1
+
+ tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w")
+ row += 1
+
+ # mask size: allow 0-6 in steps of 0.1 (or empty input)
+ def validate_custom_number(P):
+     if P == "":
+         return True  # Allow empty input
+     try:
+         num = float(P)
+         if 0 <= num <= 6 and (num.is_integer() or (num * 10) % 1 == 0):
+             return True
+     except ValueError:
+         pass
+     return False
+
+ row += 1
+ tk.Label(root, text="Mask settings:", bg="lightblue", font=("Arial", 12)).grid(row=row, column=1, sticky="sw")
+ row += 1
+ size_label = tk.Label(root, text="Mask size:", bg="lightblue", padx=50)
+ size_label.grid(row=row, column=1, sticky="w")
+ size_var = tk.StringVar()
+ size_entry = tk.Entry(root, textvariable=size_var, validate="key", width=3, validatecommand=(root.register(validate_custom_number), "%P"))
+ size_entry.grid(row=row, column=1, sticky="w", padx=120)
+ size_var.set(config["MASK"].get("size", "2.5"))
+
+ # feathering: allow whole numbers 0-3 (or empty input)
+ def validate_feather(P):
+     if P == "":
+         return True  # Allow empty input
+     try:
+         num = float(P)
+         if 0 <= num <= 3 and num.is_integer():
+             return True
+     except ValueError:
+         pass
+     return False
+
+ row += 1
+ feathering_label = tk.Label(root, text="Feathering:", bg="lightblue", padx=50)
+ feathering_label.grid(row=row, column=1, sticky="w")
+ feathering_var = tk.StringVar()
+ feathering_entry = tk.Entry(root, textvariable=feathering_var, validate="key", width=3, validatecommand=(root.register(validate_feather), "%P"))
+ feathering_entry.grid(row=row, column=1, sticky="w", padx=120)
+ feathering_var.set(config["MASK"].get("feathering", "2.5"))
+
+ # mouth_tracking
+ row += 1
+ mouth_tracking_var = tk.BooleanVar()
+ mouth_tracking_var.set(config["MASK"].getboolean("mouth_tracking", fallback=True))  # Set default value
+ mouth_tracking_checkbox = tk.Checkbutton(root, text="track mouth for mask on every frame", variable=mouth_tracking_var, bg="lightblue", padx=50)
+ mouth_tracking_checkbox.grid(row=row, column=1, sticky="w")
+
+ # debug_mask
+ row += 1
+ debug_mask_var = tk.BooleanVar()
+ debug_mask_var.set(config["MASK"].getboolean("debug_mask", fallback=True))  # Set default value
+ debug_mask_checkbox = tk.Checkbutton(root, text="highlight mask for debugging", variable=debug_mask_var, bg="lightblue", padx=50)
+ debug_mask_checkbox.grid(row=row, column=1, sticky="w")
+
+ # Increase spacing between all rows (uniformly); use a separate loop
+ # variable so the grid row counter is not clobbered
+ for r in range(row):
+     root.rowconfigure(r, weight=1)
+
+ root.mainloop()
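
Note: GUI.py talks to the launcher purely through the filesystem: it deletes any stale run.txt on startup, and the Start button saves config.ini, writes run.txt, and exits. run.py is part of this commit but not shown above, so the other side of that handshake can only be sketched; the function and file names below are assumptions for illustration, not code from this repository.

    import os
    import subprocess

    def launch_gui_and_wait(runfile="run.txt"):
        # Hypothetical launcher step: open the GUI and block until it exits,
        # then check whether it left run.txt behind as the "start" signal.
        subprocess.run(["python", "GUI.py"])
        if os.path.exists(runfile):
            os.remove(runfile)  # consume the signal
            return True         # caller should start processing config.ini
        return False            # user closed the GUI without pressing Start
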
README.md CHANGED
@@ -1,12 +1 @@
- ---
- title: Lipsing
- emoji: 🐠
- colorFrom: yellow
- colorTo: pink
- sdk: gradio
- sdk_version: 5.5.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This is the README file.
audio.py ADDED
@@ -0,0 +1,179 @@
+ import librosa
+ import librosa.filters
+ import numpy as np
+
+ # import tensorflow as tf
+ from scipy import signal
+ from scipy.io import wavfile
+ from hparams import hparams as hp
+
+
+ def load_wav(path, sr):
+     return librosa.core.load(path, sr=sr)[0]
+
+
+ def save_wav(wav, path, sr):
+     wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+     # proposed by @dsmiller
+     wavfile.write(path, sr, wav.astype(np.int16))
+
+
+ def save_wavenet_wav(wav, path, sr):
+     # NOTE: librosa.output was removed in librosa 0.8; this call requires librosa < 0.8
+     librosa.output.write_wav(path, wav, sr=sr)
+
+
+ def preemphasis(wav, k, preemphasize=True):
+     if preemphasize:
+         return signal.lfilter([1, -k], [1], wav)
+     return wav
+
+
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
+     if inv_preemphasize:
+         return signal.lfilter([1], [1, -k], wav)
+     return wav
+
+
+ def get_hop_size():
+     hop_size = hp.hop_size
+     if hop_size is None:
+         assert hp.frame_shift_ms is not None
+         hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+     return hop_size
+
+
+ def linearspectrogram(wav):
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+
+ def melspectrogram(wav):
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+
+ def _lws_processor():
+     import lws
+
+     return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+
+ def _stft(y):
+     if hp.use_lws:
+         # _lws_processor takes no arguments
+         return _lws_processor().stft(y).T
+     else:
+         return librosa.stft(
+             y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size
+         )
+
+
+ ##########################################################
+ # These are only correct when using lws!!! (This was messing with WaveNet quality for a long time!)
+ def num_frames(length, fsize, fshift):
+     """Compute the number of time frames of a spectrogram"""
+     pad = fsize - fshift
+     if length % fshift == 0:
+         M = (length + pad * 2 - fsize) // fshift + 1
+     else:
+         M = (length + pad * 2 - fsize) // fshift + 2
+     return M
+
+
+ def pad_lr(x, fsize, fshift):
+     """Compute left and right padding"""
+     M = num_frames(len(x), fsize, fshift)
+     pad = fsize - fshift
+     T = len(x) + 2 * pad
+     r = (M - 1) * fshift + fsize - T
+     return pad, pad + r
+
+
+ ##########################################################
+ # Librosa-correct padding
+ def librosa_pad_lr(x, fsize, fshift):
+     return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+
+ # Conversions
+ _mel_basis = None
+
+
+ def _linear_to_mel(spectrogram):
+     global _mel_basis
+     if _mel_basis is None:
+         _mel_basis = _build_mel_basis()
+     return np.dot(_mel_basis, spectrogram)
+
+
+ def _build_mel_basis():
+     assert hp.fmax <= hp.sample_rate // 2
+     return librosa.filters.mel(
+         sr=hp.sample_rate,
+         n_fft=hp.n_fft,
+         n_mels=hp.num_mels,
+         fmin=hp.fmin,
+         fmax=hp.fmax,
+     )
+
+
+ def _amp_to_db(x):
+     min_level = np.exp(hp.min_level_db / 20 * np.log(10))
+     return 20 * np.log10(np.maximum(min_level, x))
+
+
+ def _db_to_amp(x):
+     return np.power(10.0, (x) * 0.05)
+
+
+ def _normalize(S):
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return np.clip(
+                 (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db))
+                 - hp.max_abs_value,
+                 -hp.max_abs_value,
+                 hp.max_abs_value,
+             )
+         else:
+             return np.clip(
+                 hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)),
+                 0,
+                 hp.max_abs_value,
+             )
+
+     assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+     if hp.symmetric_mels:
+         return (2 * hp.max_abs_value) * (
+             (S - hp.min_level_db) / (-hp.min_level_db)
+         ) - hp.max_abs_value
+     else:
+         return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+
+ def _denormalize(D):
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return (
+                 (np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value)
+                 * -hp.min_level_db
+                 / (2 * hp.max_abs_value)
+             ) + hp.min_level_db
+         else:
+             return (
+                 np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value
+             ) + hp.min_level_db
+
+     if hp.symmetric_mels:
+         return (
+             (D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)
+         ) + hp.min_level_db
+     else:
+         return (D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db
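
Note: a minimal usage sketch for this module, assuming hparams.py defines the usual fields (sample_rate, n_fft, win_size, hop_size or frame_shift_ms, num_mels, fmin, fmax, and the normalization flags) and that "speech.wav" is a placeholder input file:

    import numpy as np
    import audio
    from hparams import hparams as hp

    wav = audio.load_wav("speech.wav", hp.sample_rate)  # 1-D float32 waveform
    mel = audio.melspectrogram(wav)                     # shape: (num_mels, num_frames)
    print(mel.shape, mel.dtype)

    # Optional sanity check: _denormalize should invert _normalize
    if hp.signal_normalization:
        recovered = audio._denormalize(mel)
        assert np.isfinite(recovered).all()
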
config.ini ADDED
@@ -0,0 +1,51 @@
+
+ [OPTIONS]
+
+ video_file =
+ vocal_file =
+
+ quality = Improved
+ # Options:
+ ; Fast: Wav2Lip only
+ ; Improved: Wav2Lip with a feathered mask around the mouth to remove the square around the face
+ ; Enhanced: Wav2Lip + mask + GFPGAN upscaling done on the face
+ ; Experimental: test version of applying GFPGAN - see release notes
+
+ output_height = full resolution
+
+ # Options:
+ ; full resolution
+ ; half resolution
+ ; video height in pixels, e.g. 480
+
+ wav2lip_version = Wav2Lip
+ # Wav2Lip or Wav2Lip_GAN
+
+ # Please consult the readme for this and the rest of the options:
+ ; https://github.com/anothermartz/Easy-Wav2Lip#advanced-tweaking
+
+ use_previous_tracking_data = True
+
+ nosmooth = True
+
+ preview_window = Full
+
+ [PADDING]
+ u = 0
+ d = 0
+ l = 0
+ r = 0
+
+ [MASK]
+ size = 2.5
+ feathering = 2
+ mouth_tracking = False
+ debug_mask = False
+
+ [OTHER]
+ batch_process = False
+ output_suffix = _SyncKing-Kong
+ include_settings_in_suffix = False
+ preview_settings = False
+ frame_to_preview = 100
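
Note: configparser exposes every value in this file as a string, so consumers should use the typed getters rather than raw indexing. A small sketch, using section and key names taken from the file above:

    import configparser

    config = configparser.ConfigParser()
    config.read("config.ini")

    quality = config["OPTIONS"].get("quality", "Improved")              # plain string
    nosmooth = config["OPTIONS"].getboolean("nosmooth", fallback=True)  # real bool
    mask_size = config["MASK"].getfloat("size", fallback=2.5)           # real float
    pad_u = config["PADDING"].getint("u", fallback=0)                   # real int

    print(quality, nosmooth, mask_size, pad_u)
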
degradations.py ADDED
@@ -0,0 +1,764 @@
+ import cv2
+ import math
+ import numpy as np
+ import random
+ import torch
+ from scipy import special
+ from scipy.stats import multivariate_normal
+ from torchvision.transforms.functional import rgb_to_grayscale
+
+ # -------------------------------------------------------------------- #
+ # --------------------------- blur kernels --------------------------- #
+ # -------------------------------------------------------------------- #
+
+
+ # --------------------------- util functions --------------------------- #
+ def sigma_matrix2(sig_x, sig_y, theta):
+     """Calculate the rotated sigma matrix (two dimensional matrix).
+
+     Args:
+         sig_x (float):
+         sig_y (float):
+         theta (float): Radian measurement.
+
+     Returns:
+         ndarray: Rotated sigma matrix.
+     """
+     d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+     u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+     return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+ def mesh_grid(kernel_size):
+     """Generate the mesh grid, centering at zero.
+
+     Args:
+         kernel_size (int):
+
+     Returns:
+         xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+         xx (ndarray): with the shape (kernel_size, kernel_size)
+         yy (ndarray): with the shape (kernel_size, kernel_size)
+     """
+     ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+     xx, yy = np.meshgrid(ax, ax)
+     xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
+                     yy.reshape(kernel_size * kernel_size, 1))).reshape(kernel_size, kernel_size, 2)
+     return xy, xx, yy
+
+
+ def pdf2(sigma_matrix, grid):
+     """Calculate the PDF of the bivariate Gaussian distribution.
+
+     Args:
+         sigma_matrix (ndarray): with the shape (2, 2)
+         grid (ndarray): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size.
+
+     Returns:
+         kernel (ndarray): un-normalized kernel.
+     """
+     inverse_sigma = np.linalg.inv(sigma_matrix)
+     kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+     return kernel
+
+
+ def cdf2(d_matrix, grid):
+     """Calculate the CDF of the standard bivariate Gaussian distribution.
+     Used in the skewed Gaussian distribution.
+
+     Args:
+         d_matrix (ndarray): skew matrix.
+         grid (ndarray): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size.
+
+     Returns:
+         cdf (ndarray): skewed cdf.
+     """
+     rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+     grid = np.dot(grid, d_matrix)
+     cdf = rv.cdf(grid)
+     return cdf
+
+
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+     """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+
+     In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+
+     Args:
+         kernel_size (int):
+         sig_x (float):
+         sig_y (float):
+         theta (float): Radian measurement.
+         grid (ndarray, optional): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size. Default: None
+         isotropic (bool):
+
+     Returns:
+         kernel (ndarray): normalized kernel.
+     """
+     if grid is None:
+         grid, _, _ = mesh_grid(kernel_size)
+     if isotropic:
+         sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+     else:
+         sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+     kernel = pdf2(sigma_matrix, grid)
+     kernel = kernel / np.sum(kernel)
+     return kernel
+
+
+ def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+     """Generate a bivariate generalized Gaussian kernel.
+
+     ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
+
+     In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+
+     Args:
+         kernel_size (int):
+         sig_x (float):
+         sig_y (float):
+         theta (float): Radian measurement.
+         beta (float): shape parameter, beta = 1 is the normal distribution.
+         grid (ndarray, optional): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size. Default: None
+
+     Returns:
+         kernel (ndarray): normalized kernel.
+     """
+     if grid is None:
+         grid, _, _ = mesh_grid(kernel_size)
+     if isotropic:
+         sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+     else:
+         sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+     inverse_sigma = np.linalg.inv(sigma_matrix)
+     kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+     kernel = kernel / np.sum(kernel)
+     return kernel
+
+
+ def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+     """Generate a plateau-like anisotropic kernel.
+
+     1 / (1 + x^(beta))
+
+     Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
+
+     In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
+
+     Args:
+         kernel_size (int):
+         sig_x (float):
+         sig_y (float):
+         theta (float): Radian measurement.
+         beta (float): shape parameter, beta = 1 is the normal distribution.
+         grid (ndarray, optional): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size. Default: None
+
+     Returns:
+         kernel (ndarray): normalized kernel.
+     """
+     if grid is None:
+         grid, _, _ = mesh_grid(kernel_size)
+     if isotropic:
+         sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+     else:
+         sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+     inverse_sigma = np.linalg.inv(sigma_matrix)
+     kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+     kernel = kernel / np.sum(kernel)
+     return kernel
+
+
+ def random_bivariate_Gaussian(kernel_size,
+                               sigma_x_range,
+                               sigma_y_range,
+                               rotation_range,
+                               noise_range=None,
+                               isotropic=True):
+     """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
+
+     In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
+
+     Args:
+         kernel_size (int):
+         sigma_x_range (tuple): [0.6, 5]
+         sigma_y_range (tuple): [0.6, 5]
+         rotation_range (tuple): [-math.pi, math.pi]
+         noise_range (tuple, optional): multiplicative kernel noise,
+             [0.75, 1.25]. Default: None
+
+     Returns:
+         kernel (ndarray):
+     """
+     assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+     assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+     sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+     if isotropic is False:
+         assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+         assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+         sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+         rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+     else:
+         sigma_y = sigma_x
+         rotation = 0
+
+     kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
+
+     # add multiplicative noise
+     if noise_range is not None:
+         assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+         noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+         kernel = kernel * noise
+     kernel = kernel / np.sum(kernel)
+     return kernel
+
+
+ def random_bivariate_generalized_Gaussian(kernel_size,
+                                           sigma_x_range,
+                                           sigma_y_range,
+                                           rotation_range,
+                                           beta_range,
+                                           noise_range=None,
+                                           isotropic=True):
+     """Randomly generate bivariate generalized Gaussian kernels.
+
+     In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
+
+     Args:
+         kernel_size (int):
+         sigma_x_range (tuple): [0.6, 5]
+         sigma_y_range (tuple): [0.6, 5]
+         rotation_range (tuple): [-math.pi, math.pi]
+         beta_range (tuple): [0.5, 8]
+         noise_range (tuple, optional): multiplicative kernel noise,
+             [0.75, 1.25]. Default: None
+
+     Returns:
+         kernel (ndarray):
+     """
+     assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+     assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+     sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+     if isotropic is False:
+         assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+         assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+         sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+         rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+     else:
+         sigma_y = sigma_x
+         rotation = 0
+
+     # assume beta_range[0] < 1 < beta_range[1]
+     if np.random.uniform() < 0.5:
+         beta = np.random.uniform(beta_range[0], 1)
+     else:
+         beta = np.random.uniform(1, beta_range[1])
+
+     kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+
+     # add multiplicative noise
+     if noise_range is not None:
+         assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+         noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+         kernel = kernel * noise
+     kernel = kernel / np.sum(kernel)
+     return kernel
+
+
+ def random_bivariate_plateau(kernel_size,
+                              sigma_x_range,
+                              sigma_y_range,
+                              rotation_range,
+                              beta_range,
+                              noise_range=None,
+                              isotropic=True):
+     """Randomly generate bivariate plateau kernels.
+
+     In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
+
+     Args:
+         kernel_size (int):
+         sigma_x_range (tuple): [0.6, 5]
+         sigma_y_range (tuple): [0.6, 5]
+         rotation_range (tuple): [-math.pi/2, math.pi/2]
+         beta_range (tuple): [1, 4]
+         noise_range (tuple, optional): multiplicative kernel noise,
+             [0.75, 1.25]. Default: None
+
+     Returns:
+         kernel (ndarray):
+     """
+     assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+     assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+     sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+     if isotropic is False:
+         assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+         assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+         sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+         rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+     else:
+         sigma_y = sigma_x
+         rotation = 0
+
+     # TODO: this may not be proper
+     if np.random.uniform() < 0.5:
+         beta = np.random.uniform(beta_range[0], 1)
+     else:
+         beta = np.random.uniform(1, beta_range[1])
+
+     kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+     # add multiplicative noise
+     if noise_range is not None:
+         assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+         noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+         kernel = kernel * noise
+     kernel = kernel / np.sum(kernel)
+
+     return kernel
+
+
+ def random_mixed_kernels(kernel_list,
+                          kernel_prob,
+                          kernel_size=21,
+                          sigma_x_range=(0.6, 5),
+                          sigma_y_range=(0.6, 5),
+                          rotation_range=(-math.pi, math.pi),
+                          betag_range=(0.5, 8),
+                          betap_range=(0.5, 8),
+                          noise_range=None):
+     """Randomly generate mixed kernels.
+
+     Args:
+         kernel_list (tuple): a list of kernel type names; supported types are
+             ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
+             'plateau_iso', 'plateau_aniso']
+         kernel_prob (tuple): corresponding kernel probability for each
+             kernel type
+         kernel_size (int):
+         sigma_x_range (tuple): [0.6, 5]
+         sigma_y_range (tuple): [0.6, 5]
+         rotation_range (tuple): [-math.pi, math.pi]
+         betag_range (tuple): shape range for generalized Gaussian kernels, [0.5, 8]
+         betap_range (tuple): shape range for plateau kernels, [0.5, 8]
+         noise_range (tuple, optional): multiplicative kernel noise,
+             [0.75, 1.25]. Default: None
+
+     Returns:
+         kernel (ndarray):
+     """
+     kernel_type = random.choices(kernel_list, kernel_prob)[0]
+     if kernel_type == 'iso':
+         kernel = random_bivariate_Gaussian(
+             kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
+     elif kernel_type == 'aniso':
+         kernel = random_bivariate_Gaussian(
+             kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
+     elif kernel_type == 'generalized_iso':
+         kernel = random_bivariate_generalized_Gaussian(
+             kernel_size,
+             sigma_x_range,
+             sigma_y_range,
+             rotation_range,
+             betag_range,
+             noise_range=noise_range,
+             isotropic=True)
+     elif kernel_type == 'generalized_aniso':
+         kernel = random_bivariate_generalized_Gaussian(
+             kernel_size,
+             sigma_x_range,
+             sigma_y_range,
+             rotation_range,
+             betag_range,
+             noise_range=noise_range,
+             isotropic=False)
+     elif kernel_type == 'plateau_iso':
+         kernel = random_bivariate_plateau(
+             kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
+     elif kernel_type == 'plateau_aniso':
+         kernel = random_bivariate_plateau(
+             kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
+     return kernel
+
+
+ # suppress divide-by-zero warnings from the sinc kernel below (its center value is patched afterwards)
+ np.seterr(divide='ignore', invalid='ignore')
+
+
+ def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
+     """2D sinc filter.
+
+     Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
+
+     Args:
+         cutoff (float): cutoff frequency in radians (pi is max)
+         kernel_size (int): horizontal and vertical size, must be odd.
+         pad_to (int): pad kernel size to desired size, must be odd or zero.
+     """
+     assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+     kernel = np.fromfunction(
+         lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
+             (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
+                 (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
+     kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
+     kernel = kernel / np.sum(kernel)
+     if pad_to > kernel_size:
+         pad_size = (pad_to - kernel_size) // 2
+         kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+     return kernel
+ return kernel
410
+
411
+
412
+ # ------------------------------------------------------------- #
413
+ # --------------------------- noise --------------------------- #
414
+ # ------------------------------------------------------------- #
415
+
416
+ # ----------------------- Gaussian Noise ----------------------- #
417
+
418
+
419
+ def generate_gaussian_noise(img, sigma=10, gray_noise=False):
420
+ """Generate Gaussian noise.
421
+
422
+ Args:
423
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
424
+ sigma (float): Noise scale (measured in range 255). Default: 10.
425
+
426
+ Returns:
427
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
428
+ float32.
429
+ """
430
+ if gray_noise:
431
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
432
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
433
+ else:
434
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
435
+ return noise
436
+
437
+
438
+ def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
439
+ """Add Gaussian noise.
440
+
441
+ Args:
442
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
443
+ sigma (float): Noise scale (measured in range 255). Default: 10.
444
+
445
+ Returns:
446
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
447
+ float32.
448
+ """
449
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
450
+ out = img + noise
451
+ if clip and rounds:
452
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
453
+ elif clip:
454
+ out = np.clip(out, 0, 1)
455
+ elif rounds:
456
+ out = (out * 255.0).round() / 255.
457
+ return out
458
+
459
+
460
+ def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
461
+ """Add Gaussian noise (PyTorch version).
462
+
463
+ Args:
464
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
465
+ scale (float | Tensor): Noise scale. Default: 1.0.
466
+
467
+ Returns:
468
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
469
+ float32.
470
+ """
471
+ b, _, h, w = img.size()
472
+ if not isinstance(sigma, (float, int)):
473
+ sigma = sigma.view(img.size(0), 1, 1, 1)
474
+ if isinstance(gray_noise, (float, int)):
475
+ cal_gray_noise = gray_noise > 0
476
+ else:
477
+ gray_noise = gray_noise.view(b, 1, 1, 1)
478
+ cal_gray_noise = torch.sum(gray_noise) > 0
479
+
480
+ if cal_gray_noise:
481
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
482
+ noise_gray = noise_gray.view(b, 1, h, w)
483
+
484
+ # always calculate color noise
485
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
486
+
487
+ if cal_gray_noise:
488
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
489
+ return noise
490
+
491
+
492
+ def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
493
+ """Add Gaussian noise (PyTorch version).
494
+
495
+ Args:
496
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
497
+ scale (float | Tensor): Noise scale. Default: 1.0.
498
+
499
+ Returns:
500
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
501
+ float32.
502
+ """
503
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
504
+ out = img + noise
505
+ if clip and rounds:
506
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
507
+ elif clip:
508
+ out = torch.clamp(out, 0, 1)
509
+ elif rounds:
510
+ out = (out * 255.0).round() / 255.
511
+ return out
512
+
513
+
514
+ # ----------------------- Random Gaussian Noise ----------------------- #
515
+ def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
516
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
517
+ if np.random.uniform() < gray_prob:
518
+ gray_noise = True
519
+ else:
520
+ gray_noise = False
521
+ return generate_gaussian_noise(img, sigma, gray_noise)
522
+
523
+
524
+ def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
525
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
526
+ out = img + noise
527
+ if clip and rounds:
528
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
529
+ elif clip:
530
+ out = np.clip(out, 0, 1)
531
+ elif rounds:
532
+ out = (out * 255.0).round() / 255.
533
+ return out
534
+
535
+
536
+ def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
537
+ sigma = torch.rand(
538
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
539
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
540
+ gray_noise = (gray_noise < gray_prob).float()
541
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
542
+
543
+
544
+ def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
545
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
546
+ out = img + noise
547
+ if clip and rounds:
548
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
549
+ elif clip:
550
+ out = torch.clamp(out, 0, 1)
551
+ elif rounds:
552
+ out = (out * 255.0).round() / 255.
553
+ return out
554
+
555
+
556
+ # ----------------------- Poisson (Shot) Noise ----------------------- #
557
+
558
+
559
+ def generate_poisson_noise(img, scale=1.0, gray_noise=False):
560
+ """Generate poisson noise.
561
+
562
+ Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
563
+
564
+ Args:
565
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
566
+ scale (float): Noise scale. Default: 1.0.
567
+ gray_noise (bool): Whether generate gray noise. Default: False.
568
+
569
+ Returns:
570
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
571
+ float32.
572
+ """
573
+ if gray_noise:
574
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
575
+ # round and clip image for counting vals correctly
576
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
577
+ vals = len(np.unique(img))
578
+ vals = 2**np.ceil(np.log2(vals))
579
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
580
+ noise = out - img
581
+ if gray_noise:
582
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
583
+ return noise * scale
584
+
585
+
586
+ def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
587
+ """Add poisson noise.
588
+
589
+ Args:
590
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
591
+ scale (float): Noise scale. Default: 1.0.
592
+ gray_noise (bool): Whether generate gray noise. Default: False.
593
+
594
+ Returns:
595
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
596
+ float32.
597
+ """
598
+ noise = generate_poisson_noise(img, scale, gray_noise)
599
+ out = img + noise
600
+ if clip and rounds:
601
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
602
+ elif clip:
603
+ out = np.clip(out, 0, 1)
604
+ elif rounds:
605
+ out = (out * 255.0).round() / 255.
606
+ return out
607
+
608
+
609
+ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
610
+ """Generate a batch of poisson noise (PyTorch version)
611
+
612
+ Args:
613
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
614
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
615
+ Default: 1.0.
616
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
617
+ 0 for False, 1 for True. Default: 0.
618
+
619
+ Returns:
620
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
621
+ float32.
622
+ """
623
+ b, _, h, w = img.size()
624
+ if isinstance(gray_noise, (float, int)):
625
+ cal_gray_noise = gray_noise > 0
626
+ else:
627
+ gray_noise = gray_noise.view(b, 1, 1, 1)
628
+ cal_gray_noise = torch.sum(gray_noise) > 0
629
+ if cal_gray_noise:
630
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
631
+ # round and clip image for counting vals correctly
632
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
633
+ # use for-loop to get the unique values for each sample
634
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
635
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
636
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
637
+ out = torch.poisson(img_gray * vals) / vals
638
+ noise_gray = out - img_gray
639
+ noise_gray = noise_gray.expand(b, 3, h, w)
640
+
641
+ # always calculate color noise
642
+ # round and clip image for counting vals correctly
643
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
644
+ # use for-loop to get the unique values for each sample
645
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
646
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
647
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
648
+ out = torch.poisson(img * vals) / vals
649
+ noise = out - img
650
+ if cal_gray_noise:
651
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
652
+ if not isinstance(scale, (float, int)):
653
+ scale = scale.view(b, 1, 1, 1)
654
+ return noise * scale
655
+
656
+
657
+ def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
658
+ """Add poisson noise to a batch of images (PyTorch version).
659
+
660
+ Args:
661
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
662
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
663
+ Default: 1.0.
664
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
665
+ 0 for False, 1 for True. Default: 0.
666
+
667
+ Returns:
668
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
669
+ float32.
670
+ """
671
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
672
+ out = img + noise
673
+ if clip and rounds:
674
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
675
+ elif clip:
676
+ out = torch.clamp(out, 0, 1)
677
+ elif rounds:
678
+ out = (out * 255.0).round() / 255.
679
+ return out
680
+
681
+
682
+ # ----------------------- Random Poisson (Shot) Noise ----------------------- #
683
+
684
+
685
+ def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
686
+ scale = np.random.uniform(scale_range[0], scale_range[1])
687
+ if np.random.uniform() < gray_prob:
688
+ gray_noise = True
689
+ else:
690
+ gray_noise = False
691
+ return generate_poisson_noise(img, scale, gray_noise)
692
+
693
+
694
+ def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
695
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
696
+ out = img + noise
697
+ if clip and rounds:
698
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
699
+ elif clip:
700
+ out = np.clip(out, 0, 1)
701
+ elif rounds:
702
+ out = (out * 255.0).round() / 255.
703
+ return out
704
+
705
+
706
+ def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
707
+ scale = torch.rand(
708
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
709
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
710
+ gray_noise = (gray_noise < gray_prob).float()
711
+ return generate_poisson_noise_pt(img, scale, gray_noise)
712
+
713
+
714
+ def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
715
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
716
+ out = img + noise
717
+ if clip and rounds:
718
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
719
+ elif clip:
720
+ out = torch.clamp(out, 0, 1)
721
+ elif rounds:
722
+ out = (out * 255.0).round() / 255.
723
+ return out
724
+
725
+
726
+ # ------------------------------------------------------------------------ #
727
+ # --------------------------- JPEG compression --------------------------- #
728
+ # ------------------------------------------------------------------------ #
729
+
730
+
731
+ def add_jpg_compression(img, quality=90):
732
+ """Add JPG compression artifacts.
733
+
734
+ Args:
735
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
736
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
737
+ best quality. Default: 90.
738
+
739
+ Returns:
740
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
741
+ float32.
742
+ """
743
+ img = np.clip(img, 0, 1)
744
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
745
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
746
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
747
+ return img
748
+
749
+
750
+ def random_add_jpg_compression(img, quality_range=(90, 100)):
751
+ """Randomly add JPG compression artifacts.
752
+
753
+ Args:
754
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
755
+ quality_range (tuple[float] | list[float]): JPG compression quality
756
+ range. 0 for lowest quality, 100 for best quality.
757
+ Default: (90, 100).
758
+
759
+ Returns:
760
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
761
+ float32.
762
+ """
763
+ quality = np.random.uniform(quality_range[0], quality_range[1])
764
+ return add_jpg_compression(img, quality)
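
Note: chained together, the helpers above form the classical noise-then-compress degradation pipeline used when synthesizing training data for restoration models such as GFPGAN. A sketch on a float32 image in [0, 1]; the branch probability and parameter ranges are illustrative only:

    import numpy as np
    from degradations import (random_add_gaussian_noise,
                              random_add_poisson_noise,
                              random_add_jpg_compression)

    def degrade(img, gray_prob=0.4):
        # img: float32 array, shape (h, w, c), range [0, 1]
        if np.random.uniform() < 0.5:
            img = random_add_gaussian_noise(img, sigma_range=(1, 15), gray_prob=gray_prob)
        else:
            img = random_add_poisson_noise(img, scale_range=(0.05, 2), gray_prob=gray_prob)
        img = random_add_jpg_compression(img, quality_range=(60, 95))
        return np.clip(img, 0, 1)
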
easy_functions.py ADDED
@@ -0,0 +1,196 @@
import torch
import subprocess
import json
import os
import dlib
import gdown
import pickle
import re
from fractions import Fraction
from models import Wav2Lip
from base64 import b64encode
from urllib.parse import urlparse
from torch.hub import download_url_to_file, get_dir
from IPython.display import HTML, display

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'


def get_video_details(filename):
    cmd = [
        "ffprobe",
        "-v",
        "error",
        "-show_format",
        "-show_streams",
        "-of",
        "json",
        filename,
    ]
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    info = json.loads(result.stdout)

    # Get video stream
    video_stream = next(
        stream for stream in info["streams"] if stream["codec_type"] == "video"
    )

    # Get resolution
    width = int(video_stream["width"])
    height = int(video_stream["height"])
    resolution = width * height

    # Get fps - avg_frame_rate is a fraction string such as "30000/1001",
    # so parse it explicitly rather than calling eval() on ffprobe output
    fps = float(Fraction(video_stream["avg_frame_rate"]))

    # Get length
    length = float(info["format"]["duration"])

    return width, height, fps, length


def show_video(file_path):
    """Function to display video in Colab"""
    mp4 = open(file_path, "rb").read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    width, _, _, _ = get_video_details(file_path)
    display(
        HTML(
            """
            <video controls width=%d>
                <source src="%s" type="video/mp4">
            </video>
            """
            % (min(width, 1280), data_url)
        )
    )


def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = int(seconds % 60)

    if hours > 0:
        return f"{hours}h {minutes}m {seconds}s"
    elif minutes > 0:
        return f"{minutes}m {seconds}s"
    else:
        return f"{seconds}s"


def _load(checkpoint_path):
    if device != "cpu":
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
    return checkpoint


def load_model(path):
    # If a cached, pre-loaded model exists alongside the checkpoint, load it and return
    working_directory = os.getcwd()
    folder, filename_with_extension = os.path.split(path)
    filename, file_type = os.path.splitext(filename_with_extension)
    results_file = os.path.join(folder, filename + ".pk1")  # ".pk1" (digit one) is the cache extension used throughout
    if os.path.exists(results_file):
        with open(results_file, "rb") as f:
            return pickle.load(f)
    model = Wav2Lip()
    print("Loading {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace("module.", "")] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    # Cache the loaded model so subsequent runs can skip checkpoint parsing
    with open(results_file, "wb") as f:
        pickle.dump(model.eval(), f)
    # os.remove(path)
    return model.eval()


def get_input_length(filename):
    result = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            filename,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return float(result.stdout)


def is_url(string):
    url_regex = re.compile(r"^(https?|ftp)://[^\s/$.?#].[^\s]*$")
    return bool(url_regex.match(string))


def load_predictor():
    checkpoint = os.path.join(
        "checkpoints", "shape_predictor_68_face_landmarks_GTX.dat"
    )
    predictor = dlib.shape_predictor(checkpoint)
    mouth_detector = dlib.get_frontal_face_detector()

    # Serialize the variables
    with open(os.path.join("checkpoints", "predictor.pkl"), "wb") as f:
        pickle.dump(predictor, f)

    with open(os.path.join("checkpoints", "mouth_detector.pkl"), "wb") as f:
        pickle.dump(mouth_detector, f)

    # delete the .dat file as it is no longer needed
    # os.remove(output)


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, downloading the model if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, "checkpoints")

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


def g_colab():
    try:
        import google.colab

        return True
    except ImportError:
        return False
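A minimal sketch of the ffprobe helpers above in use. It assumes ffprobe is on the PATH and uses a hypothetical "input.mp4":

    from easy_functions import format_time, get_input_length, get_video_details

    width, height, fps, length = get_video_details("input.mp4")
    print(f"{width}x{height} @ {fps:.3f} fps")
    print("duration:", format_time(get_input_length("input.mp4")))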
enhance.py ADDED
@@ -0,0 +1,22 @@
import warnings
from gfpgan import GFPGANer

warnings.filterwarnings("ignore")


def load_sr():
    run_params = GFPGANer(
        model_path="checkpoints/GFPGANv1.4.pth",
        upscale=1,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=None,
    )
    return run_params


def upscale(image, properties):
    _, _, output = properties.enhance(
        image, has_aligned=False, only_center_face=False, paste_back=True
    )
    return output
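A minimal usage sketch, assuming checkpoints/GFPGANv1.4.pth has already been downloaded by install.py; "crop.jpg" is a hypothetical BGR face crop:

    import cv2
    from enhance import load_sr, upscale

    run_params = load_sr()               # GFPGANer configured for 1x restoration
    face_crop = cv2.imread("crop.jpg")   # hypothetical uint8 BGR image
    restored = upscale(face_crop, run_params)
    cv2.imwrite("crop_restored.jpg", restored)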
hparams.py ADDED
@@ -0,0 +1,99 @@
from glob import glob
import os


def get_image_list(data_root, split):
    filelist = []

    with open("filelists/{}.txt".format(split)) as f:
        for line in f:
            line = line.strip()
            if " " in line:
                line = line.split()[0]
            filelist.append(os.path.join(data_root, line))

    return filelist


class HParams:
    def __init__(self, **kwargs):
        self.data = {}

        for key, value in kwargs.items():
            self.data[key] = value

    def __getattr__(self, key):
        if key not in self.data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return self.data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


# Default hyperparameters
hparams = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value
    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!!
    use_lws=False,
    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)
    frame_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)
    # Mel and Linear spectrograms normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to some predefined range (following below parameters)
    allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # faster and cleaner convergence)
    max_abs_value=4.0,
    # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
    # be too big to avoid gradient explosion,
    # not too small for fast convergence)
    # Contribution by @begeekmyfriend
    # Spectrogram Pre-Emphasis (Lfilter: Reduces spectrogram noise and helps model certitude
    # levels. Also allows for better G&L phase reconstruction)
    preemphasize=True,  # whether to apply filter
    preemphasis=0.97,  # filter coefficient.
    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male! If female, 95 should help take off noise. (To
    # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    fmax=7600,  # To be increased/reduced depending on data.
    ###################### Our training parameters #################################
    img_size=96,
    fps=25,
    batch_size=16,
    initial_learning_rate=1e-4,
    nepochs=200000000000000000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
    num_workers=16,
    checkpoint_interval=3000,
    eval_interval=3000,
    save_optimizer_state=True,
    syncnet_wt=0.0,  # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
    syncnet_batch_size=64,
    syncnet_lr=1e-4,
    syncnet_eval_interval=10000,
    syncnet_checkpoint_interval=10000,
    disc_wt=0.07,
    disc_initial_learning_rate=1e-4,
)


def hparams_debug_string():
    # HParams stores its values in .data (the class defines no values() method)
    values = hparams.data
    hp = [
        "  %s: %s" % (name, values[name])
        for name in sorted(values)
        if name != "sentences"
    ]
    return "Hyperparameters:\n" + "\n".join(hp)
inference.py ADDED
@@ -0,0 +1,781 @@
print("\rloading torch       ", end="")
import torch

print("\rloading numpy       ", end="")
import numpy as np

print("\rloading Image       ", end="")
from PIL import Image

print("\rloading argparse    ", end="")
import argparse

print("\rloading configparser", end="")
import configparser

print("\rloading math        ", end="")
import math

print("\rloading os          ", end="")
import os

print("\rloading subprocess  ", end="")
import subprocess

print("\rloading pickle      ", end="")
import pickle

print("\rloading cv2         ", end="")
import cv2

print("\rloading audio       ", end="")
import audio

print("\rloading RetinaFace  ", end="")
from batch_face import RetinaFace

print("\rloading re          ", end="")
import re

print("\rloading partial     ", end="")
from functools import partial

print("\rloading tqdm        ", end="")
from tqdm import tqdm

print("\rloading warnings    ", end="")
import warnings

warnings.filterwarnings(
    "ignore", category=UserWarning, module="torchvision.transforms.functional_tensor"
)
print("\rloading upscale     ", end="")
from enhance import upscale

print("\rloading load_sr     ", end="")
from enhance import load_sr

print("\rloading load_model  ", end="")
from easy_functions import load_model, g_colab

print("\rimports loaded!     ")

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
gpu_id = 0 if torch.cuda.is_available() else -1

if device == 'cpu':
    print('Warning: No GPU detected so inference will be done on the CPU which is VERY SLOW!')

parser = argparse.ArgumentParser(
    description="Inference code to lip-sync videos in the wild using Wav2Lip models"
)

parser.add_argument(
    "--checkpoint_path",
    type=str,
    help="Name of saved checkpoint to load weights from",
    required=True,
)

parser.add_argument(
    "--segmentation_path",
    type=str,
    default="checkpoints/face_segmentation.pth",
    help="Name of saved checkpoint of segmentation network",
    required=False,
)

parser.add_argument(
    "--face",
    type=str,
    help="Filepath of video/image that contains faces to use",
    required=True,
)
parser.add_argument(
    "--audio",
    type=str,
    help="Filepath of video/audio file to use as raw audio source",
    required=True,
)
parser.add_argument(
    "--outfile",
    type=str,
    help="Video path to save result. See default for an example.",
    default="results/result_voice.mp4",
)

# Note: the boolean-ish options below are passed as the strings "True"/"False"
# by run.py, which is why they are compared with str(...) == "True" throughout
# this file.
parser.add_argument(
    "--static",
    type=bool,
    help="If True, then use only first video frame for inference",
    default=False,
)
parser.add_argument(
    "--fps",
    type=float,
    help="Can be specified only if input is a static image (default: 25)",
    default=25.0,
    required=False,
)

parser.add_argument(
    "--pads",
    nargs="+",
    type=int,
    default=[0, 10, 0, 0],
    help="Padding (top, bottom, left, right). Please adjust to include at least the chin",
)

parser.add_argument(
    "--wav2lip_batch_size", type=int, help="Batch size for Wav2Lip model(s)", default=1
)

parser.add_argument(
    "--out_height",
    default=480,
    type=int,
    help="Output video height. Best results are obtained at 480 or 720",
)

parser.add_argument(
    "--crop",
    nargs="+",
    type=int,
    default=[0, -1, 0, -1],
    help="Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. "
    "Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width",
)

parser.add_argument(
    "--box",
    nargs="+",
    type=int,
    default=[-1, -1, -1, -1],
    help="Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. "
    "Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).",
)

parser.add_argument(
    "--rotate",
    default=False,
    action="store_true",
    help="Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. "
    "Use if you get a flipped result, despite feeding a normal looking video",
)

parser.add_argument(
    "--nosmooth",
    type=str,
    default=False,
    help="Prevent smoothing face detections over a short temporal window",
)

parser.add_argument(
    "--no_seg",
    default=False,
    action="store_true",
    help="Prevent using face segmentation",
)

parser.add_argument(
    "--no_sr", default=False, action="store_true", help="Prevent using super resolution"
)

parser.add_argument(
    "--sr_model",
    type=str,
    default="gfpgan",
    help="Name of upscaler - gfpgan or RestoreFormer",
    required=False,
)

parser.add_argument(
    "--fullres",
    default=3,
    type=int,
    help="used only to determine if full res is used so that no resizing needs to be done if so",
)

parser.add_argument(
    "--debug_mask",
    type=str,
    default=False,
    help="Makes background grayscale to see the mask better",
)

parser.add_argument(
    "--preview_settings", type=str, default=False, help="Processes only one frame"
)

parser.add_argument(
    "--mouth_tracking",
    type=str,
    default=False,
    help="Tracks the mouth in every frame for the mask",
)

parser.add_argument(
    "--mask_dilation",
    default=150,
    type=float,
    help="size of mask around mouth",
    required=False,
)

parser.add_argument(
    "--mask_feathering",
    default=151,
    type=int,
    help="amount of feathering of mask around mouth",
    required=False,
)

parser.add_argument(
    "--quality",
    type=str,
    help="Choose between Fast, Improved and Enhanced",
    default="Fast",
)

with open(os.path.join("checkpoints", "predictor.pkl"), "rb") as f:
    predictor = pickle.load(f)

with open(os.path.join("checkpoints", "mouth_detector.pkl"), "rb") as f:
    mouth_detector = pickle.load(f)

# creating variables to prevent failing when a face isn't detected
kernel = last_mask = x = y = w = h = None

g_colab = g_colab()

if not g_colab:
    # Load the config file
    config = configparser.ConfigParser()
    config.read('config.ini')

    # Get the value of the "preview_window" variable
    preview_window = config.get('OPTIONS', 'preview_window')

all_mouth_landmarks = []

model = detector = detector_model = None


def do_load(checkpoint_path):
    global model, detector, detector_model
    model = load_model(checkpoint_path)
    detector = RetinaFace(
        gpu_id=gpu_id, model_path="checkpoints/mobilenet.pth", network="mobilenet"
    )
    detector_model = detector.model


def face_rect(images):
    face_batch_size = 8
    num_batches = math.ceil(len(images) / face_batch_size)
    prev_ret = None
    for i in range(num_batches):
        batch = images[i * face_batch_size : (i + 1) * face_batch_size]
        all_faces = detector(batch)  # return faces list of all images
        for faces in all_faces:
            if faces:
                box, landmarks, score = faces[0]
                prev_ret = tuple(map(int, box))
            # yield the last successful box when no face is found in this frame
            yield prev_ret


def create_tracked_mask(img, original_img):
    global kernel, last_mask, x, y, w, h  # Add last_mask to global variables

    # Convert color space from BGR to RGB if necessary
    cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
    cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img)

    # Detect face
    faces = mouth_detector(img)
    if len(faces) == 0:
        if last_mask is not None:
            last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0]))
            mask = last_mask  # use the last successful mask
        else:
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
            return img, None
    else:
        face = faces[0]
        shape = predictor(img, face)

        # Get points for mouth
        mouth_points = np.array(
            [[shape.part(i).x, shape.part(i).y] for i in range(48, 68)]
        )

        # Calculate bounding box dimensions
        x, y, w, h = cv2.boundingRect(mouth_points)

        # Set kernel size as a fraction of bounding box size
        kernel_size = int(max(w, h) * args.mask_dilation)
        # if kernel_size % 2 == 0:  # Ensure kernel size is odd
        #     kernel_size += 1

        # Create kernel
        kernel = np.ones((kernel_size, kernel_size), np.uint8)

        # Create binary mask for mouth
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        cv2.fillConvexPoly(mask, mouth_points, 255)

        last_mask = mask  # Update last_mask with the new mask

    # Dilate the mask
    dilated_mask = cv2.dilate(mask, kernel)

    # Calculate distance transform of dilated mask
    dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5)

    # Normalize distance transform
    cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX)

    # Convert normalized distance transform to binary mask and convert it to uint8
    _, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY)
    masked_diff = masked_diff.astype(np.uint8)

    # make sure blur is an odd number
    blur = args.mask_feathering
    if blur % 2 == 0:
        blur += 1
    # Set blur size as a fraction of bounding box size
    blur = int(max(w, h) * blur)
    if blur % 2 == 0:  # Ensure blur size is odd
        blur += 1
    masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0)

    # Convert numpy arrays to PIL Images
    input1 = Image.fromarray(img)
    input2 = Image.fromarray(original_img)

    # Convert mask to single channel where pixel values are from the alpha channel of the current mask
    mask = Image.fromarray(masked_diff)

    # Ensure images are the same size
    assert input1.size == input2.size == mask.size

    # Paste input1 onto input2 using the mask
    input2.paste(input1, (0, 0), mask)

    # Convert the final PIL Image back to a numpy array
    input2 = np.array(input2)

    # input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB)
    cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2)

    return input2, mask


def create_mask(img, original_img):
    global kernel, last_mask, x, y, w, h  # Add last_mask to global variables

    # Convert color space from BGR to RGB if necessary
    cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
    cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img)

    if last_mask is not None:
        last_mask = np.array(last_mask)  # Convert PIL Image to numpy array
        last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0]))
        mask = last_mask  # use the last successful mask
        mask = Image.fromarray(mask)

    else:
        # Detect face
        faces = mouth_detector(img)
        if len(faces) == 0:
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
            return img, None
        else:
            face = faces[0]
            shape = predictor(img, face)

            # Get points for mouth
            mouth_points = np.array(
                [[shape.part(i).x, shape.part(i).y] for i in range(48, 68)]
            )

            # Calculate bounding box dimensions
            x, y, w, h = cv2.boundingRect(mouth_points)

            # Set kernel size as a fraction of bounding box size
            kernel_size = int(max(w, h) * args.mask_dilation)
            # if kernel_size % 2 == 0:  # Ensure kernel size is odd
            #     kernel_size += 1

            # Create kernel
            kernel = np.ones((kernel_size, kernel_size), np.uint8)

            # Create binary mask for mouth
            mask = np.zeros(img.shape[:2], dtype=np.uint8)
            cv2.fillConvexPoly(mask, mouth_points, 255)

            # Dilate the mask
            dilated_mask = cv2.dilate(mask, kernel)

            # Calculate distance transform of dilated mask
            dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5)

            # Normalize distance transform
            cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX)

            # Convert normalized distance transform to binary mask and convert it to uint8
            _, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY)
            masked_diff = masked_diff.astype(np.uint8)

            if not args.mask_feathering == 0:
                blur = args.mask_feathering
                # Set blur size as a fraction of bounding box size
                blur = int(max(w, h) * blur)
                if blur % 2 == 0:  # Ensure blur size is odd
                    blur += 1
                masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0)

            # Convert mask to single channel where pixel values are from the alpha channel of the current mask
            mask = Image.fromarray(masked_diff)

            last_mask = mask  # Update last_mask with the final mask after dilation and feathering

    # Convert numpy arrays to PIL Images
    input1 = Image.fromarray(img)
    input2 = Image.fromarray(original_img)

    # Resize mask to match image size
    # mask = Image.fromarray(mask)
    mask = mask.resize(input1.size)

    # Ensure images are the same size
    assert input1.size == input2.size == mask.size

    # Paste input1 onto input2 using the mask
    input2.paste(input1, (0, 0), mask)

    # Convert the final PIL Image back to a numpy array
    input2 = np.array(input2)

    # input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB)
    cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2)

    return input2, mask


def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T :]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images, results_file="last_detected_face.pkl"):
    # If results file exists, load it and return
    if os.path.exists(results_file):
        print("Using face detection data from last input")
        with open(results_file, "rb") as f:
            return pickle.load(f)

    results = []
    pady1, pady2, padx1, padx2 = args.pads

    tqdm_partial = partial(tqdm, position=0, leave=True)
    for image, rect in tqdm_partial(
        zip(images, face_rect(images)),
        total=len(images),
        desc="detecting face in every frame",
        ncols=100,
    ):
        if rect is None:
            cv2.imwrite(
                "temp/faulty_frame.jpg", image
            )  # check this frame where the face was not detected.
            raise ValueError(
                "Face not detected! Ensure the video contains a face in all the frames."
            )

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if str(args.nosmooth) == "False":
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [
        [image[y1:y2, x1:x2], (y1, y2, x1, x2)]
        for image, (x1, y1, x2, y2) in zip(images, boxes)
    ]

    # Save results to file
    with open(results_file, "wb") as f:
        pickle.dump(results, f)

    return results


def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    print("\r" + " " * 100, end="\r")
    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print("Using the specified bounding box instead of face detection...")
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size // 2 :] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
            mel_batch = np.reshape(
                mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]
            )

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size // 2 :] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
        mel_batch = np.reshape(
            mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]
        )

        yield img_batch, mel_batch, frame_batch, coords_batch


mel_step_size = 16


def _load(checkpoint_path):
    if device != "cpu":
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
    return checkpoint


def main():
    args.img_size = 96
    frame_number = 11

    # use splitext rather than split(".") so paths containing dots are handled
    if os.path.isfile(args.face) and os.path.splitext(args.face)[1].lower() in [
        ".jpg", ".png", ".jpeg",
    ]:
        args.static = True

    if not os.path.isfile(args.face):
        raise ValueError("--face argument must be a valid path to video/image file")

    elif os.path.splitext(args.face)[1].lower() in [".jpg", ".png", ".jpeg"]:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps

    else:
        if args.fullres != 1:
            print("Resizing video...")
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break

            if args.fullres != 1:
                aspect_ratio = frame.shape[1] / frame.shape[0]
                frame = cv2.resize(
                    frame, (int(args.out_height * aspect_ratio), args.out_height)
                )

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1:
                x2 = frame.shape[1]
            if y2 == -1:
                y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    if not args.audio.endswith(".wav"):
        print("Converting audio to .wav")
        subprocess.check_call(
            [
                "ffmpeg",
                "-y",
                "-loglevel",
                "error",
                "-i",
                args.audio,
                "temp/temp.wav",
            ]
        )
        args.audio = "temp/temp.wav"

    print("analysing audio...")
    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError(
            "Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again"
        )

    mel_chunks = []

    mel_idx_multiplier = 80.0 / fps
    i = 0
    while 1:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size :])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    full_frames = full_frames[: len(mel_chunks)]
    if str(args.preview_settings) == "True":
        full_frames = [full_frames[0]]
        mel_chunks = [mel_chunks[0]]
    print(str(len(full_frames)) + " frames to process")
    batch_size = args.wav2lip_batch_size
    if str(args.preview_settings) == "True":
        gen = datagen(full_frames, mel_chunks)
    else:
        gen = datagen(full_frames.copy(), mel_chunks)

    for i, (img_batch, mel_batch, frames, coords) in enumerate(
        tqdm(
            gen,
            total=int(np.ceil(float(len(mel_chunks)) / batch_size)),
            desc="Processing Wav2Lip",
            ncols=100,
        )
    ):
        if i == 0:
            if not args.quality == "Fast":
                print(
                    f"mask size: {args.mask_dilation}, feathering: {args.mask_feathering}"
                )
                if not args.quality == "Improved":
                    print("Loading", args.sr_model)
                    run_params = load_sr()

            print("Starting...")
            frame_h, frame_w = full_frames[0].shape[:-1]
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter("temp/result.mp4", fourcc, fps, (frame_w, frame_h))

        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0

        for p, f, c in zip(pred, frames, coords):
            # cv2.imwrite('temp/f.jpg', f)

            y1, y2, x1, x2 = c

            if (
                str(args.debug_mask) == "True"
            ):  # makes the background black & white so you can see the mask better
                f = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY)
                f = cv2.cvtColor(f, cv2.COLOR_GRAY2BGR)

            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
            cf = f[y1:y2, x1:x2]

            if args.quality == "Enhanced":
                p = upscale(p, run_params)

            if args.quality in ["Enhanced", "Improved"]:
                if str(args.mouth_tracking) == "True":
                    p, last_mask = create_tracked_mask(p, cf)
                else:
                    p, last_mask = create_mask(p, cf)

            f[y1:y2, x1:x2] = p

            if not g_colab:
                # Display the frame
                if preview_window == "Face":
                    cv2.imshow("face preview - press Q to abort", p)
                elif preview_window == "Full":
                    cv2.imshow("full preview - press Q to abort", f)
                elif preview_window == "Both":
                    cv2.imshow("face preview - press Q to abort", p)
                    cv2.imshow("full preview - press Q to abort", f)

                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    exit()  # Abort when 'Q' is pressed

            if str(args.preview_settings) == "True":
                cv2.imwrite("temp/preview.jpg", f)
                if not g_colab:
                    cv2.imshow("preview - press Q to close", f)
                    if cv2.waitKey(-1) & 0xFF == ord('q'):
                        exit()  # Close when 'Q' is pressed

            else:
                out.write(f)

    # Close the window(s) when done
    cv2.destroyAllWindows()

    out.release()

    if str(args.preview_settings) == "False":
        print("converting to final video")

        subprocess.check_call([
            "ffmpeg",
            "-y",
            "-loglevel",
            "error",
            "-i",
            "temp/result.mp4",
            "-i",
            args.audio,
            "-c:v",
            "libx264",
            args.outfile
        ])


if __name__ == "__main__":
    args = parser.parse_args()
    do_load(args.checkpoint_path)
    main()
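A minimal sketch of invoking inference.py directly. run.py normally builds this command; the file paths here are hypothetical, and the checkpoints are the ones install.py downloads:

    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable, "inference.py",
            "--checkpoint_path", "checkpoints/Wav2Lip.pth",
            "--face", "input.mp4",
            "--audio", "speech.wav",
            "--outfile", "results/result_voice.mp4",
            "--quality", "Improved",
        ],
        check=True,
    )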
install.py ADDED
@@ -0,0 +1,96 @@
version = 'v8.3'

import os
import re
import argparse
import shutil
import subprocess
import warnings
from IPython.display import clear_output

from easy_functions import (format_time,
                            load_file_from_url,
                            load_model,
                            load_predictor)

warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")


# Get the location of the basicsr package
def get_basicsr_location():
    result = subprocess.run(['pip', 'show', 'basicsr'], capture_output=True, text=True)
    for line in result.stdout.split('\n'):
        if 'Location: ' in line:
            return line.split('Location: ')[1]
    return None


# Move and replace a file to the basicsr location
def move_and_replace_file_to_basicsr(file_name):
    basicsr_location = get_basicsr_location()
    if basicsr_location:
        destination = os.path.join(basicsr_location, file_name)
        # Move and replace the file
        shutil.copyfile(file_name, destination)
        print(f'File replaced at {destination}')
    else:
        print('Could not find basicsr location.')


# Replace basicsr's degradations.py with the patched copy shipped in this repo
file_to_replace = 'degradations.py'
move_and_replace_file_to_basicsr(file_to_replace)


from enhance import load_sr

working_directory = os.getcwd()

# download and initialize both wav2lip models
print("downloading wav2lip essentials")
load_file_from_url(
    url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/Wav2Lip_GAN.pth",
    model_dir="checkpoints",
    progress=True,
    file_name="Wav2Lip_GAN.pth",
)
model = load_model(os.path.join(working_directory, "checkpoints", "Wav2Lip_GAN.pth"))
print("wav2lip_gan loaded")
load_file_from_url(
    url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/Wav2Lip.pth",
    model_dir="checkpoints",
    progress=True,
    file_name="Wav2Lip.pth",
)
model = load_model(os.path.join(working_directory, "checkpoints", "Wav2Lip.pth"))
print("wav2lip loaded")

# download gfpgan files
print("downloading gfpgan essentials")
load_file_from_url(
    url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/GFPGANv1.4.pth",
    model_dir="checkpoints",
    progress=True,
    file_name="GFPGANv1.4.pth",
)
load_sr()

# load face detectors
print("initializing face detectors")
load_file_from_url(
    url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/shape_predictor_68_face_landmarks_GTX.dat",
    model_dir="checkpoints",
    progress=True,
    file_name="shape_predictor_68_face_landmarks_GTX.dat",
)

load_predictor()

# write a file to signify setup is done
with open("installed.txt", "w") as f:
    f.write(version)
print("Installation complete!")
print(
    "If you just updated from v8 - make sure to download the updated Easy-Wav2Lip.bat too!"
)
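A quick sanity check after running the installer. install.py writes the version string to installed.txt, which GUI.py reads on startup:

    # installed.txt should contain the version string written above, e.g. v8.3
    with open("installed.txt") as f:
        print("installed:", f.read().strip())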
requirements.txt ADDED
@@ -0,0 +1,18 @@
basicsr==1.4.2
batch-face==1.4.0
dlib==19.24.2
facexlib==0.3.0
gdown==4.7.1
gfpgan==1.3.8
imageio-ffmpeg==0.4.9
importlib-metadata==6.8.0
ipython==8.16.1
librosa==0.10.1
moviepy==1.0.3
numpy==1.26.1
opencv-python==4.8.1.78
scipy==1.11.3
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
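The pinned dependencies can be installed programmatically as well as from the command line; this sketch is equivalent to `pip install -r requirements.txt`:

    import subprocess
    import sys

    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
    )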
run.py ADDED
@@ -0,0 +1,496 @@
import os
import sys
import re
import argparse
from easy_functions import (format_time,
                            get_input_length,
                            get_video_details,
                            show_video,
                            g_colab)
import contextlib
import shutil
import subprocess
import time
from IPython.display import Audio, Image, clear_output, display
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
import configparser

parser = argparse.ArgumentParser(description='SyncKing-Kong main run file')

parser.add_argument('-video_file', type=str,
                    help='Input video file path', required=False, default=False)
parser.add_argument('-vocal_file', type=str,
                    help='Input audio file path', required=False, default=False)
parser.add_argument('-output_file', type=str,
                    help='Output video file path', required=False, default=False)
args = parser.parse_args()

# retrieve variables from config.ini
config = configparser.ConfigParser()

config.read('config.ini')
if args.video_file:
    video_file = args.video_file
else:
    video_file = config['OPTIONS']['video_file']

if args.vocal_file:
    vocal_file = args.vocal_file
else:
    vocal_file = config['OPTIONS']['vocal_file']
quality = config['OPTIONS']['quality']
output_height = config['OPTIONS']['output_height']
wav2lip_version = config['OPTIONS']['wav2lip_version']
use_previous_tracking_data = config['OPTIONS']['use_previous_tracking_data']
nosmooth = config.getboolean('OPTIONS', 'nosmooth')
U = config.getint('PADDING', 'U')
D = config.getint('PADDING', 'D')
L = config.getint('PADDING', 'L')
R = config.getint('PADDING', 'R')
size = config.getfloat('MASK', 'size')
feathering = config.getint('MASK', 'feathering')
mouth_tracking = config.getboolean('MASK', 'mouth_tracking')
debug_mask = config.getboolean('MASK', 'debug_mask')
batch_process = config.getboolean('OTHER', 'batch_process')
output_suffix = config['OTHER']['output_suffix']
include_settings_in_suffix = config.getboolean('OTHER', 'include_settings_in_suffix')

if g_colab():
    preview_input = config.getboolean("OTHER", "preview_input")
else:
    preview_input = False
preview_settings = config.getboolean("OTHER", "preview_settings")
frame_to_preview = config.getint("OTHER", "frame_to_preview")

working_directory = os.getcwd()


start_time = time.time()

video_file = video_file.strip('"')
vocal_file = vocal_file.strip('"')

# check video_file exists
if video_file == "":
    sys.exit("video_file cannot be blank")

if os.path.isdir(video_file):
    sys.exit(f"{video_file} is a directory, you need to point to a file")

if not os.path.exists(video_file):
    sys.exit(f"Could not find file: {video_file}")

if wav2lip_version == "Wav2Lip_GAN":
    checkpoint_path = os.path.join(working_directory, "checkpoints", "Wav2Lip_GAN.pth")
else:
    checkpoint_path = os.path.join(working_directory, "checkpoints", "Wav2Lip.pth")

# remap feathering values (3 -> 5 first, so that 2 -> 3 is not remapped again)
if feathering == 3:
    feathering = 5
if feathering == 2:
    feathering = 3

resolution_scale = 1
res_custom = False
if output_height == "half resolution":
    resolution_scale = 2
elif output_height == "full resolution":
    resolution_scale = 1
else:
    res_custom = True
    resolution_scale = 3

in_width, in_height, in_fps, in_length = get_video_details(video_file)
out_height = round(in_height / resolution_scale)

if res_custom:
    out_height = int(output_height)
fps_for_static_image = 30


if output_suffix == "" and not include_settings_in_suffix:
    sys.exit(
        "Current suffix settings will overwrite your input video! Please add a suffix or tick include_settings_in_suffix"
    )

frame_to_preview = max(frame_to_preview - 1, 0)

if include_settings_in_suffix:
    if wav2lip_version == "Wav2Lip_GAN":
        output_suffix = f"{output_suffix}_GAN"
    output_suffix = f"{output_suffix}_{quality}"
    if output_height != "full resolution":
        output_suffix = f"{output_suffix}_{out_height}"
    if nosmooth:
        output_suffix = f"{output_suffix}_nosmooth1"
    else:
        output_suffix = f"{output_suffix}_nosmooth0"
    if U != 0 or D != 0 or L != 0 or R != 0:
        output_suffix = f"{output_suffix}_pads-"
        if U != 0:
            output_suffix = f"{output_suffix}U{U}"
        if D != 0:
            output_suffix = f"{output_suffix}D{D}"
        if L != 0:
            output_suffix = f"{output_suffix}L{L}"
        if R != 0:
            output_suffix = f"{output_suffix}R{R}"
    if quality != "Fast":  # quality values are capitalized: Fast/Improved/Enhanced
        output_suffix = f"{output_suffix}_mask-S{size}F{feathering}"
    if mouth_tracking:
        output_suffix = f"{output_suffix}_mt"
    if debug_mask:
        output_suffix = f"{output_suffix}_debug"
if preview_settings:
    output_suffix = f"{output_suffix}_preview"


rescaleFactor = str(round(1 // resolution_scale))
pad_up = str(round(U * resolution_scale))
pad_down = str(round(D * resolution_scale))
pad_left = str(round(L * resolution_scale))
pad_right = str(round(R * resolution_scale))
################################################################################


######################### reconstruct input paths ##############################
# Extract each part of the path
folder, filename_with_extension = os.path.split(video_file)
filename, file_type = os.path.splitext(filename_with_extension)

# Extract filenumber if it exists
filenumber_match = re.search(r"\d+$", filename)
if filenumber_match:  # if there is a filenumber - extract it
    filenumber = str(filenumber_match.group())
    filenamenonumber = re.sub(r"\d+$", "", filename)
else:  # if there is no filenumber - make it blank
    filenumber = ""
    filenamenonumber = filename

# if vocal_file is blank - use the video as audio
if vocal_file == "":
    vocal_file = video_file
# if not, check that the vocal_file file exists
else:
    if not os.path.exists(vocal_file):
        sys.exit(f"Could not find file: {vocal_file}")
    if os.path.isdir(vocal_file):
        sys.exit(f"{vocal_file} is a directory, you need to point to a file")

# Extract each part of the path
audio_folder, audio_filename_with_extension = os.path.split(vocal_file)
audio_filename, audio_file_type = os.path.splitext(audio_filename_with_extension)

# Extract filenumber if it exists
audio_filenumber_match = re.search(r"\d+$", audio_filename)
if audio_filenumber_match:  # if there is a filenumber - extract it
    audio_filenumber = str(audio_filenumber_match.group())
    audio_filenamenonumber = re.sub(r"\d+$", "", audio_filename)
else:  # if there is no filenumber - make it blank
    audio_filenumber = ""
    audio_filenamenonumber = audio_filename
################################################################################

# set process_failed to False so that it may be set to True if one or more processings fail
process_failed = False


temp_output = os.path.join(working_directory, "temp", "output.mp4")
temp_folder = os.path.join(working_directory, "temp")

last_input_video = None
last_input_audio = None

# --------------------------Batch processing loop-------------------------------!
while True:

    # construct input_video
    input_video = os.path.join(folder, filenamenonumber + str(filenumber) + file_type)
    input_videofile = os.path.basename(input_video)

    # construct input_audio
    input_audio = os.path.join(
        audio_folder, audio_filenamenonumber + str(audio_filenumber) + audio_file_type
    )
    input_audiofile = os.path.basename(input_audio)

    # see if filenames are different:
    if filenamenonumber + str(filenumber) != audio_filenamenonumber + str(
        audio_filenumber
    ):
        output_filename = (
            filenamenonumber
            + str(filenumber)
            + "_"
            + audio_filenamenonumber
            + str(audio_filenumber)
        )
    else:
        output_filename = filenamenonumber + str(filenumber)

    # construct output_video
    output_video = os.path.join(folder, output_filename + output_suffix + ".mp4")
    output_video = os.path.normpath(output_video)
    output_videofile = os.path.basename(output_video)

    # remove last outputs
    if os.path.exists("temp"):
        shutil.rmtree("temp")
    os.makedirs("temp", exist_ok=True)

    # preview inputs (if enabled)
    if preview_input:
        print("input video:")
        show_video(input_video)
        if vocal_file != "":
            print("input audio:")
            display(Audio(input_audio))
        else:
            print("using", input_videofile, "for audio")
        print("You may want to check now that they're the correct files!")

    last_input_video = input_video
    last_input_audio = input_audio
    shutil.copy(input_video, temp_folder)
    shutil.copy(input_audio, temp_folder)

    # rename temp file to include padding or else changing padding does nothing
    temp_input_video = os.path.join(temp_folder, input_videofile)
    renamed_temp_input_video = os.path.join(
        temp_folder, str(U) + str(D) + str(L) + str(R) + input_videofile
    )
    shutil.copy(temp_input_video, renamed_temp_input_video)
    temp_input_video = renamed_temp_input_video
    temp_input_videofile = os.path.basename(renamed_temp_input_video)
    temp_input_audio = os.path.join(temp_folder, input_audiofile)

    # trim video if it's longer than the audio
    video_length = get_input_length(temp_input_video)
    audio_length = get_input_length(temp_input_audio)

    if preview_settings:
        batch_process = False

        preview_length_seconds = 1
        converted_preview_frame = frame_to_preview / in_fps
        preview_start_time = min(
            converted_preview_frame, video_length - preview_length_seconds
        )

        preview_video_path = os.path.join(
            temp_folder,
            "preview_"
            + str(preview_start_time)
            + "_"
            + str(U)
            + str(D)
            + str(L)
            + str(R)
            + input_videofile,
        )
        preview_audio_path = os.path.join(temp_folder, "preview_" + input_audiofile)

        subprocess.call(
            [
                "ffmpeg",
                "-loglevel",
                "error",
                "-i",
                temp_input_video,
                "-ss",
                str(preview_start_time),
                "-to",
                str(preview_start_time + preview_length_seconds),
                "-c",
                "copy",
                preview_video_path,
            ]
        )
        subprocess.call(
            [
                "ffmpeg",
                "-loglevel",
                "error",
                "-i",
                temp_input_audio,
                "-ss",
                str(preview_start_time),
                "-to",
                str(preview_start_time + 1),
                "-c",
                "copy",
                preview_audio_path,
            ]
        )
        temp_input_video = preview_video_path
        temp_input_audio = preview_audio_path

    if video_length > audio_length:
        trimmed_video_path = os.path.join(
            temp_folder, "trimmed_" + temp_input_videofile
        )
        with open(os.devnull, "w") as devnull:
            with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                devnull
            ):
                ffmpeg_extract_subclip(
                    temp_input_video, 0, audio_length, targetname=trimmed_video_path
                )
        temp_input_video = trimmed_video_path
    # check if face detection has already happened on this clip
    last_detected_face = os.path.join(working_directory, "last_detected_face.pkl")
    if os.path.isfile("last_file.txt"):
        with open("last_file.txt", "r") as file:
            last_file = file.readline()
        if last_file != temp_input_video or use_previous_tracking_data == "False":
            if os.path.isfile(last_detected_face):
                os.remove(last_detected_face)

    # ----------------------------Process the inputs!-----------------------------!
    print(
        f"Processing{' preview of' if preview_settings else ''} "
        f"{input_videofile} using {input_audiofile} for audio"
    )

    # execute Wav2Lip & upscaler

    cmd = [
        sys.executable,
        "inference.py",
        "--face",
        temp_input_video,
        "--audio",
        temp_input_audio,
        "--outfile",
        temp_output,
        "--pads",
        str(pad_up),
        str(pad_down),
        str(pad_left),
        str(pad_right),
        "--checkpoint_path",
        checkpoint_path,
        "--out_height",
        str(out_height),
        "--fullres",
        str(resolution_scale),
        "--quality",
        quality,
        "--mask_dilation",
        str(size),
        "--mask_feathering",
        str(feathering),
        "--nosmooth",
        str(nosmooth),
        "--debug_mask",
        str(debug_mask),
        "--preview_settings",
        str(preview_settings),
        "--mouth_tracking",
        str(mouth_tracking),
    ]

    # Run the command
    subprocess.run(cmd)

    if preview_settings:
        if os.path.isfile(os.path.join(temp_folder, "preview.jpg")):
            print("preview successful! Check out temp/preview.jpg")
            with open("last_file.txt", "w") as f:
                f.write(temp_input_video)
            # end processing timer and format the time it took
            end_time = time.time()
            elapsed_time = end_time - start_time
            formatted_setup_time = format_time(elapsed_time)
            print(f"Execution time: {formatted_setup_time}")
            break

        else:
            print("Processing failed! :( see line above 👆")
            print("Consider searching the issues tab on the github:")
            print("https://github.com/anothermartz/Easy-Wav2Lip/issues")
            exit()

    # rename temp file and move to correct directory
    if os.path.isfile(temp_output):
        if os.path.isfile(output_video):
            os.remove(output_video)
        shutil.copy(temp_output, output_video)
        # show output video
        with open("last_file.txt", "w") as f:
            f.write(temp_input_video)
        print(f"{output_filename} successfully lip synced! It will be found here:")
        print(output_video)

        # end processing timer and format the time it took
        end_time = time.time()
        elapsed_time = end_time - start_time
        formatted_setup_time = format_time(elapsed_time)
        print(f"Execution time: {formatted_setup_time}")

    else:
        print("Processing failed! :( see line above 👆")
        print("Consider searching the issues tab on the github:")
        print("https://github.com/anothermartz/Easy-Wav2Lip/issues")
        process_failed = True

    if not batch_process:
        if process_failed:
            exit()
        else:
            break

    elif filenumber == "" and audio_filenumber == "":
        print("Files not set for batch processing")
        break

    # -----------------------------Batch Processing!------------------------------!
    if filenumber != "":  # if video has a filenumber
        match = re.search(r"\d+", filenumber)
        # add 1 to video filenumber
        filenumber = (
            f"{filenumber[:match.start()]}{int(match.group())+1:0{len(match.group())}d}"
        )

    if audio_filenumber != "":  # if audio has a filenumber
        match = re.search(r"\d+", audio_filenumber)
        # add 1 to audio filenumber
        audio_filenumber = f"{audio_filenumber[:match.start()]}{int(match.group())+1:0{len(match.group())}d}"

    # construct input_video
    input_video = os.path.join(folder, filenamenonumber + str(filenumber) + file_type)
    input_videofile = os.path.basename(input_video)
    # construct input_audio
    input_audio = os.path.join(
        audio_folder, audio_filenamenonumber + str(audio_filenumber) + audio_file_type
    )
    input_audiofile = os.path.basename(input_audio)

    # now check which input files exist and what to do for each scenario

    # both +1 files exist - continue processing
    if os.path.exists(input_video) and os.path.exists(input_audio):
        continue

    # video +1 only - continue with last audio file
    if os.path.exists(input_video) and input_video != last_input_video:
        if audio_filenumber != "":  # if audio has a filenumber
            match = re.search(r"\d+", audio_filenumber)
            # take 1 from audio filenumber
            audio_filenumber = f"{audio_filenumber[:match.start()]}{int(match.group())-1:0{len(match.group())}d}"
        continue

    # audio +1 only - continue with last video file
    if os.path.exists(input_audio) and input_audio != last_input_audio:
        if filenumber != "":  # if video has a filenumber
            match = re.search(r"\d+", filenumber)
            # take 1 from video filenumber
            filenumber = f"{filenumber[:match.start()]}{int(match.group())-1:0{len(match.group())}d}"
        continue

    # neither +1 files exist or current files already processed - finish processing
    print("Finished all sequentially numbered files")
    if process_failed:
        sys.exit("Processing failed on at least one video")
    else:
        break
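A minimal sketch of the filenumber increment used by the batch loop above: the trailing digits are incremented while their zero-padding width is preserved, which is how Video009.mp4 advances to Video010.mp4:

    import re

    filenumber = "009"
    match = re.search(r"\d+", filenumber)
    filenumber = f"{filenumber[:match.start()]}{int(match.group()) + 1:0{len(match.group())}d}"
    print(filenumber)  # -> "010"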