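"""
Helper utilities for audio-visual data preparation: downloading and extracting
archives (metadata is fetched from the MuAVIC public URLs), text normalization,
YouTube downloads, audio/video probing and frame extraction, landmark-based mouth
cropping, and reading/writing tab-separated audio-visual manifests.
"""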
import pickle
import re
import tarfile
import warnings
from collections import deque
from urllib.error import HTTPError

import cv2
import ffmpeg
import numpy as np
import pandas as pd
import sox
import wget
import yt_dlp
from skimage import transform
from tqdm import tqdm


def is_empty(path):
    """Return True if the directory at `path` contains no entries."""
    return not any(path.iterdir())


def read_txt_file(txt_filepath):
    with open(txt_filepath) as fin:
        return (line.strip() for line in fin.readlines())


def write_txt_file(lines, out_txt_filepath):
    with open(out_txt_filepath, "w") as fout:
        fout.writelines("\n".join([ln.strip() for ln in lines]))


def normalize_text(text):
    PUNCS = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~؟؛,’‘×÷"
    # Remove any parenthesized text.
    text = re.sub(r"\([^)]*\)", "", text)
    # Strip punctuation (including Arabic marks) and lowercase.
    text = text.translate(str.maketrans("", "", PUNCS))
    text = text.lower()
    return text.strip()


def download_file(url, download_path):
    filename = url.rpartition("/")[-1]
    if not (download_path / filename).exists():
        try:
            print(f"Downloading {filename} from {url}")
            # Progress bar that reports progress in megabytes.
            custom_bar = (
                lambda current, total, width=80: wget.bar_adaptive(
                    round(current / 1024 / 1024, 2),
                    round(total / 1024 / 1024, 2),
                    width,
                )
                + " MB"
            )
            wget.download(url, out=str(download_path / filename), bar=custom_bar)
        except HTTPError as e:
            message = f"Downloading {filename} failed!"
            raise HTTPError(e.url, e.code, message, e.hdrs, e.fp)
    return True


def extract_tgz(tgz_filepath, extract_path, out_filename=None):
    if not tgz_filepath.exists():
        raise FileNotFoundError(f"{tgz_filepath} is not found!!")
    tgz_filename = tgz_filepath.name
    tgz_object = tarfile.open(tgz_filepath)
    if not out_filename:
        out_filename = tgz_object.getnames()[0]

    if not (extract_path / out_filename).exists():
        for mem in tqdm(tgz_object.getmembers(), desc=f"Extracting {tgz_filename}"):
            out_filepath = extract_path / mem.get_info()["name"]
            if mem.isfile() and not out_filepath.exists():
                tgz_object.extract(mem, path=extract_path)
    tgz_object.close()


def download_extract_file_if_not(url, tgz_filepath, download_filename):
    download_path = tgz_filepath.parent
    if not tgz_filepath.exists():
        download_file(url, download_path)
    extract_tgz(tgz_filepath, download_path, download_filename)


def load_meanface_metadata(metadata_path):
    mean_face_filepath = metadata_path / "20words_mean_face.npy"
    if not mean_face_filepath.exists():
        download_file(
            "https://dl.fbaipublicfiles.com/muavic/metadata/20words_mean_face.npy",
            metadata_path,
        )
    return np.load(mean_face_filepath)


def load_video_metadata(filepath):
    if not filepath.exists():
        # Download & extract the language-specific metadata archive first.
        lang_dir = filepath.parent.parent
        lang = lang_dir.name
        tgz_filepath = lang_dir.parent / f"{lang}_metadata.tgz"
        download_extract_file_if_not(
            url=f"https://dl.fbaipublicfiles.com/muavic/metadata/{lang}_metadata.tgz",
            tgz_filepath=tgz_filepath,
            download_filename=lang,
        )
    if not filepath.exists():
        # The metadata file is not part of the downloaded archive.
        return None
    with open(filepath, "rb") as fin:
        metadata = pickle.load(fin)
    return metadata


def download_video_from_youtube(download_path, yt_id):
    """Downloads a video from YouTube given its id on YouTube"""
    video_out_path = download_path / f"{yt_id}.mp4"
    if video_out_path.exists():
        downloaded = True
    else:
        url = f"https://www.youtube.com/watch?v={yt_id}"
        ydl_opts = {"quiet": True, "format": "mp4", "outtmpl": str(video_out_path)}
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            try:
                ydl.download([url])
                downloaded = True
            except yt_dlp.utils.DownloadError:
                downloaded = False
    return downloaded


def resize_frames(input_frames, new_size):
    resized_frames = []
    for frame in input_frames:
        try:
            resized_frames.append(cv2.resize(frame, new_size))
        except Exception:
            # Skip frames that cannot be resized.
            pass
    return resized_frames


def get_audio_duration(audio_filepath):
    return sox.file_info.duration(audio_filepath)


def get_video_duration(video_filepath):
    try:
        streams = ffmpeg.probe(video_filepath)["streams"]
        for stream in streams:
            if stream["codec_type"] == "video":
                return float(stream["duration"])
    except Exception:
        warnings.warn(f"Video file: `{video_filepath}` is corrupted... skipping!!")
    return -1


def get_video_resolution(video_filepath):
    for stream in ffmpeg.probe(video_filepath)["streams"]:
        if stream["codec_type"] == "video":
            height = int(stream["height"])
            width = int(stream["width"])
            return height, width
    raise TypeError(f"Input file: {video_filepath} doesn't have video stream!")


def get_audio_video_info(audio_path, video_path, fid):
    audio_filepath = audio_path / f"{fid}.wav"
    video_filepath = video_path / f"{fid}.mp4"
    # Counts assume 16 kHz audio and 25 fps video; -1 marks missing files.
    audio_frames = (
        int(get_audio_duration(audio_filepath) * 16_000)
        if audio_filepath.exists()
        else -1
    )
    video_frames = (
        int(get_video_duration(video_filepath) * 25) if video_filepath.exists() else -1
    )
    return {
        "id": fid,
        "video": str(video_filepath),
        "audio": str(audio_filepath),
        "video_frames": video_frames,
        "audio_samples": audio_frames,
    }


def split_video_to_frames(video_filepath, fstart=None, fend=None, out_fps=25):
    """Yield raw BGR frames from a video file, optionally trimmed to [fstart, fend)."""
    height, width = get_video_resolution(video_filepath)
    video_stream = ffmpeg.input(str(video_filepath)).video.filter("fps", fps=out_fps)
    channels = 3
    process = None
    try:
        if fstart is not None and fend is not None:
            process = (
                video_stream.trim(start_frame=fstart, end_frame=fend)
                .setpts("PTS-STARTPTS")
                .output("pipe:", format="rawvideo", pix_fmt="bgr24")
                .run_async(pipe_stdout=True, quiet=True)
            )
            frames_counter = 0
            while frames_counter < fend - fstart:
                in_bytes = process.stdout.read(width * height * channels)
                in_frame = np.frombuffer(in_bytes, np.uint8).reshape(
                    height, width, channels
                )
                yield in_frame
                frames_counter += 1
        else:
            process = (
                video_stream.setpts("PTS-STARTPTS")
                .output("pipe:", format="rawvideo", pix_fmt="bgr24")
                .run_async(pipe_stdout=True, quiet=True)
            )
            while True:
                in_bytes = process.stdout.read(width * height * channels)
                if not in_bytes:
                    break
                in_frame = np.frombuffer(in_bytes, np.uint8).reshape(
                    height, width, channels
                )
                yield in_frame
    finally:
        if process is not None:
            process.stdout.close()
            process.wait()


def save_video(frames, out_filepath, fps, vcodec="libx264"):
    if len(frames) == 0:
        warnings.warn(f"Video segment `{out_filepath.stem}` has no metadata... skipping!!")
        return
    height, width, _ = frames[0].shape
    process = (
        ffmpeg.input(
            "pipe:", format="rawvideo", pix_fmt="bgr24", s="{}x{}".format(width, height)
        )
        .output(str(out_filepath), pix_fmt="bgr24", vcodec=vcodec, r=fps)
        .overwrite_output()
        .run_async(pipe_stdin=True, quiet=True)
    )
    for frame in frames:
        try:
            process.stdin.write(frame.astype(np.uint8).tobytes())
        except Exception:
            # Surface the encoder's error output instead of failing silently.
            print(process.stderr.read())
    process.stdin.close()
    process.wait()


def load_video(filename):
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            yield frame
        else:
            break
    cap.release()


def warp_img(src, dst, img, std_size):
    """Warp `img` with the similarity transform mapping `src` landmarks onto `dst`."""
    tform = transform.estimate_transform("similarity", src, dst)
    warped = transform.warp(img, inverse_map=tform.inverse, output_shape=std_size)
    warped = (warped * 255).astype("uint8")  # transform.warp returns floats in [0, 1]
    return warped, tform


def apply_transform(trans, img, std_size):
    """Apply a previously estimated transform to another image."""
    warped = transform.warp(img, inverse_map=trans.inverse, output_shape=std_size)
    warped = (warped * 255).astype("uint8")
    return warped


def cut_patch(img, metadata, height, width, threshold=5):
    """Cut a (2*height x 2*width) patch centered on the mean of the given landmarks."""
    center_x, center_y = np.mean(metadata, axis=0)
    # Clamp the center so the patch stays inside the image; fail if it is too far out.
    if center_y - height < 0:
        center_y = height
    if center_y - height < 0 - threshold:
        raise Exception("too much bias in height")
    if center_x - width < 0:
        center_x = width
    if center_x - width < 0 - threshold:
        raise Exception("too much bias in width")

    if center_y + height > img.shape[0]:
        center_y = img.shape[0] - height
    if center_y + height > img.shape[0] + threshold:
        raise Exception("too much bias in height")
    if center_x + width > img.shape[1]:
        center_x = img.shape[1] - width
    if center_x + width > img.shape[1] + threshold:
        raise Exception("too much bias in width")

    cutted_img = np.copy(
        img[
            int(round(center_y) - round(height)) : int(round(center_y) + round(height)),
            int(round(center_x) - round(width)) : int(round(center_x) + round(width)),
        ]
    )
    return cutted_img


def crop_patch(
    video_frames,
    num_frames,
    metadata,
    mean_face_metadata,
    std_size=(256, 256),
    window_margin=12,
    start_idx=48,
    stop_idx=68,
    crop_height=96,
    crop_width=96,
):
    """Crop mouth patch"""
    # Nose-tip and eye-corner landmarks (68-point scheme) used to stabilize the warp;
    # indices 48-68 cover the mouth region.
    stablePntsIDs = [33, 36, 39, 42, 45]
    margin = min(num_frames, window_margin)
    q_frame, q_metadata = deque(), deque()
    sequence = []
    for frame_idx, frame in enumerate(video_frames):
        if frame_idx >= len(metadata):
            break
        q_metadata.append(metadata[frame_idx])
        q_frame.append(frame)
        if len(q_frame) == margin:
            # Smooth landmarks over the sliding window before warping.
            smoothed_metadata = np.mean(q_metadata, axis=0)
            cur_metadata = q_metadata.popleft()
            cur_frame = q_frame.popleft()

            trans_frame, trans = warp_img(
                smoothed_metadata[stablePntsIDs, :],
                mean_face_metadata[stablePntsIDs, :],
                cur_frame,
                std_size,
            )
            trans_metadata = trans(cur_metadata)

            sequence.append(
                cut_patch(
                    trans_frame,
                    trans_metadata[start_idx:stop_idx],
                    crop_height // 2,
                    crop_width // 2,
                )
            )

    # Flush the remaining frames using the last estimated transform.
    while q_frame:
        cur_frame = q_frame.popleft()
        trans_frame = apply_transform(trans, cur_frame, std_size)
        trans_metadata = trans(q_metadata.popleft())
        sequence.append(
            cut_patch(
                trans_frame,
                trans_metadata[start_idx:stop_idx],
                crop_height // 2,
                crop_width // 2,
            )
        )
    return sequence


def read_av_manifest(tsv_filepath):
    """Read a tab-separated audio-visual manifest (the first line is a root-dir header)."""
    with open(tsv_filepath) as fin:
        res = []
        for ln in fin.readlines()[1:]:
            id_, video, audio, video_frames, audio_samples = ln.strip().split("\t")
            res.append(
                {
                    "id": id_,
                    "video": video,
                    "audio": audio,
                    "video_frames": video_frames,
                    "audio_samples": audio_samples,
                }
            )
    df = pd.DataFrame(res)
    df["video_frames"] = df["video_frames"].astype(int)
    df["audio_samples"] = df["audio_samples"].astype(int)
    return df


def write_av_manifest(df, out_filepath):
    with open(out_filepath, "w") as fout:
        # Root-directory header line expected by the manifest format.
        fout.write("/\n")
    df.to_csv(out_filepath, sep="\t", header=False, index=False, mode="a")
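
# Illustrative usage sketch (not part of the original pipeline) showing how the
# helpers above chain together to crop mouth regions from a single clip.
# Paths and file names below are hypothetical placeholders.
#
#   from pathlib import Path
#
#   mean_face = load_meanface_metadata(Path("metadata"))
#   landmarks = load_video_metadata(Path(".../some_id.pkl"))  # per-frame face landmarks
#   frames = list(load_video("some_id.mp4"))
#   mouth = crop_patch(frames, len(frames), landmarks, mean_face)
#   save_video(mouth, Path("some_id_mouth.mp4"), fps=25)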