import os |
import cv2 |
import sys |
import time |
import random |
import shutil |
import hashlib |
import logging |
import argparse |
from tqdm import tqdm |
from pathlib import Path |
from ffmpy import FFmpeg |
import glob |
import pdb |
import torchaudio |
import random |
import torch |
import numpy as np |
from scipy.io import wavfile |
from jiwer import wer, cer |
import json |
from faster_whisper import WhisperModel |
import shutil |
# Seed every random number generator this script touches (torch CPU/CUDA,
# NumPy and the stdlib `random` module) with one fixed value.
random_seed = 1234
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

# Put the parent of this file's directory on the import path before the
# project-local imports below.
sys.path.insert(0, str(Path(__file__).parent.parent))
# The star import supplies names that the rest of this file uses unqualified
# (presumably DETECTOR, PREDICTOR, track_time, VIDEOS_CACHE, ... -- confirm
# against demo_utils).
from demo_utils import *  # noqa: F401,F403
from utils import (
    split_video_to_frames,
    resize_frames,
    crop_patch,
    save_video,
)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def detect_landmark(image):
    """Find the 68 facial landmark points in one RGB frame.

    The frame is converted to grayscale, handed to the module-level
    ``DETECTOR`` to obtain face rectangles, and every rectangle is run
    through ``PREDICTOR``. When several rectangles come back, the landmarks
    of the last one are the ones returned.

    Returns:
        A ``(68, 2)`` ``int32`` array of ``(x, y)`` points, or ``None`` when
        the detector yields no rectangle.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    face_rects = DETECTOR(grayscale, 1)
    landmarks = None
    for face_rect in face_rects:
        predicted = PREDICTOR(grayscale, face_rect)
        landmarks = np.array(
            [(predicted.part(k).x, predicted.part(k).y) for k in range(68)],
            dtype=np.int32,
        )
    return landmarks
@track_time
def extract_lip_movement(
    webcam_video,
    in_video_filepath,
    out_lip_filepath,
    num_workers
):
    """Turn a raw recording into a 25 fps lip-region video.

    Steps: re-encode ``webcam_video`` to ``in_video_filepath`` (skipped when
    that file already exists), split it into frames, detect face landmarks on
    every frame, then either crop the mouth region or -- when too many frames
    have no landmarks -- fall back to resizing the whole frames. The result
    is written to ``out_lip_filepath``.

    Args:
        webcam_video: Path of the source recording.
        in_video_filepath: Path of the re-encoded (25 fps, 640x480) video.
        out_lip_filepath: Path the lip-movement video is saved to.
        num_workers: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If no frame can be read from ``in_video_filepath``.
    """

    def copy_video_if_ready(webcam_video, out_path):
        # Currently unused: its only call site is the disabled cache lookup
        # noted below. Kept so the lookup can be re-enabled.
        with open(webcam_video, 'rb') as fin:
            curr_md5hash = hashlib.md5(fin.read()).hexdigest()
        if curr_md5hash in VIDEOS_CACHE:
            dst_path = VIDEOS_CACHE[curr_md5hash]
            shutil.copy(dst_path / "video.mp4", out_path)
            shutil.copy(dst_path / "lip_movement.mp4", out_path)
            shutil.copy(dst_path / "raw_video.md5", out_path)
            return True
        else:
            VIDEOS_CACHE[curr_md5hash] = out_path
            with open(out_path / "raw_video.md5", 'w') as fout:
                fout.write(curr_md5hash)
            return False

    # Disabled cache lookup (kept for reference):
    # if copy_video_if_ready(webcam_video, in_video_filepath.parent):
    #     logger.info("Skip video processing; Loading the cached one!!")
    #     return

    logger.info("Adjust video framerate to 25")
    if not os.path.isfile(in_video_filepath):
        FFmpeg(
            inputs={webcam_video: None},
            # `-vf` is an alias of `-filter:v`; passing both keeps only the
            # last one, which silently dropped the fps filter. Chain both
            # filters in a single `-vf` instead.
            outputs={in_video_filepath: "-v quiet -vf fps=25,scale=640:480"},
        ).run()

    logger.info("Converting video into frames")
    frames = list(split_video_to_frames(in_video_filepath))
    if not frames:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            f"No frames could be read from `{in_video_filepath}`"
        )

    logger.info("Extract face landmarks from video frames")
    landmarks = [
        detect_landmark(frame)
        for frame in tqdm(frames, desc="Detecting Lip Movement")
    ]
    invalid_landmarks_ratio = sum(lnd is None for lnd in landmarks) / len(landmarks)
    logger.info(f"Current invalid frame ratio ({invalid_landmarks_ratio}) ")

    if invalid_landmarks_ratio > MAX_MISSING_FRAMES_RATIO:
        # Too many frames without landmarks: skip cropping and resize the
        # full frames instead.
        logger.info(
            "Invalid frame ratio exceeded maximum allowed ratio!! " +
            "Starting resizing the recorded video!!"
        )
        sequence = resize_frames(frames)
    else:
        if invalid_landmarks_ratio != 0:
            logger.info("Linearly-interpolate invalid landmarks")
            continuous_landmarks = landmarks_interpolate(landmarks)
        else:
            continuous_landmarks = landmarks
        logger.info("Cropping the mouth region.")
        sequence = crop_patch(
            frames,
            len(frames),
            continuous_landmarks,
        )
    save_video(sequence, out_lip_filepath, fps=25)
def process_input_video(
    model_type: str,
    input_video_path: str,
    noise_snr: int,
    noise_type: str,
    outpath: str,
):
    """Run the noisy audio-visual recognition demo on one recording.

    Extracts the lip-movement video (if not already present in ``outpath``),
    mixes the recording's audio with a randomly chosen noise file of
    ``noise_type`` at ``noise_snr``, muxes the noisy audio with the lip video
    and runs ``infer_av_hubert`` on the result.

    Args:
        model_type: Key into the module-level ``AV_RESOURCES`` mapping.
        input_video_path: Path of the recorded video; ``None`` is rejected.
        noise_snr: SNR value forwarded to ``mix_audio_with_noise``.
        noise_type: Key into the module-level ``NOISE`` mapping.
        outpath: Directory for all produced files (``str`` or ``Path``).

    Returns:
        ``(noisy_lip_video_path, av_text)`` where the path is a ``str``.

    Raises:
        IOError: If ``input_video_path`` is ``None``.
    """
    if input_video_path is None:
        raise IOError(
            "Gradio didn't record the video. Refresh the web page, please!!"
        )
    # Accept plain strings as the annotation promises; `/` needs a Path.
    outpath = Path(outpath)
    audio_filepath = outpath / "audio.wav"
    video_filepath = outpath / "video.mp4"
    noisy_audio_filepath = outpath / "noisy_audio.wav"
    lip_video_filepath = outpath / "lip_movement.mp4"

    # Only the lip video decides whether extraction is needed. Requiring the
    # re-encoded video to be missing as well would skip extraction when just
    # `video.mp4` exists and then fail at the mux step below.
    if not os.path.isfile(lip_video_filepath):
        extract_lip_movement(
            input_video_path, video_filepath, lip_video_filepath,
            # os.cpu_count() may return None.
            num_workers=min(os.cpu_count() or 1, 5)
        )

    logger.info(f"Mixing audio with `{noise_type}` noise (SNR={noise_snr}).")
    noise_wav_file = random.choice(NOISE[noise_type])
    logger.debug(f"Noise Wav used is {noise_wav_file}")
    mixed = mix_audio_with_noise(
        input_video_path, audio_filepath, noisy_audio_filepath,
        noise_wav_file, noise_snr
    )

    logger.info("Adding noisy audio with the lip-movement video.")
    noisy_lip_filepath = outpath / "noisy_lip_movement.mp4"
    FFmpeg(
        inputs={noisy_audio_filepath: None, lip_video_filepath: None},
        outputs={noisy_lip_filepath: "-v quiet -c:v copy -c:a aac"},
    ).run()

    resources = AV_RESOURCES[model_type]
    av_text = infer_av_hubert(
        resources["model"],
        resources["task"],
        resources["generator"],
        lip_video_filepath,
        noisy_audio_filepath,
        duration=len(mixed) / 16000
    )
    logger.info(f"Av-HuBERT Output: {av_text}")

    logger.info("Summary:")
    for k, v in TIME_TRACKER.items():
        logger.info(f'Function {k} executed in {v} seconds')
    logger.info(30 * '=' + " Done! " + '=' * 30)
    return (str(noisy_lip_filepath), av_text)
# Characters removed from transcripts before scoring.
_TRANSCRIPT_STRIP_TABLE = str.maketrans("", "", ".,!;:?/\n")


def _normalize_transcript(text):
    """Drop punctuation and newlines, lower-case and trim ``text``."""
    return text.translate(_TRANSCRIPT_STRIP_TABLE).lower().strip()


def test_WER(
    model_type: str,
    input_video_path: str,
    gt_text: str,
    noise_type: str,
    model_name: str,
    noise_name: str,
    noise_wav_file: str,
    outpath: str,
    file_name: str,
    is_valid: dict,
):
    """Score one video against its reference text at SNR -7.5 and -10.

    For each SNR the clean audio is mixed with ``noise_wav_file``, muxed with
    the lip-movement video, recognised with ``infer_av_hubert`` and compared
    with ``gt_text`` using ``jiwer``'s ``wer`` and ``cer``. The per-SNR
    scratch directory is deleted after scoring.

    Args:
        model_type: Key into the module-level ``AV_RESOURCES`` mapping.
        input_video_path: Path of the video to evaluate; ``None`` is rejected.
        gt_text: Reference transcript.
        noise_type: Noise category; used as a directory name.
        model_name: Name of the evaluated model; used as a directory name.
        noise_name: Noise sample name; used as a directory name.
        noise_wav_file: Path of the noise wav file to mix in.
        outpath: Root output directory (``str`` or ``Path``).
        file_name: Identifier of the video; used as a directory name and as
            a key into ``is_valid``.
        is_valid: Mapping ``file_name -> 0/1``. It is only read for the
            progress print and returned unchanged.

    Returns:
        ``(wer_list, cer_list, is_valid)`` with one entry per SNR, ordered
        ``[-7.5, -10]``.

    Raises:
        IOError: If ``input_video_path`` is ``None``.
    """
    if input_video_path is None:
        raise IOError(
            "Gradio didn't record the video. Refresh the web page, please!!"
        )
    # Accept plain strings as the annotation promises; `/` needs a Path.
    outpath = Path(outpath)
    out_filepath = outpath / model_name / file_name
    out_filepath.mkdir(parents=True, exist_ok=True)
    audio_filepath = out_filepath / "audio.wav"
    video_filepath = out_filepath / "video.mp4"
    lip_video_filepath = out_filepath / "lip_movement.mp4"
    noisy_audio_path = outpath / model_name / noise_type / noise_name
    noisy_audio_path.mkdir(parents=True, exist_ok=True)

    if not os.path.isfile(lip_video_filepath):
        extract_lip_movement(
            input_video_path, video_filepath, lip_video_filepath,
            # os.cpu_count() may return None.
            num_workers=min(os.cpu_count() or 1, 5)
        )
    if not os.path.isfile(audio_filepath):
        FFmpeg(
            inputs={input_video_path: None},
            outputs={audio_filepath: "-v quiet -vn -acodec pcm_s16le -ar 16000 -ac 1"},
        ).run()
    sr, audio = wavfile.read(audio_filepath)
    _, noise = wavfile.read(noise_wav_file)

    # NOTE: scoring of the clean (no added noise) audio is disabled; only the
    # noisy conditions below contribute to the returned lists.

    # The reference does not change between SNRs, so normalise it once.
    gt_text = _normalize_transcript(gt_text)
    resources = AV_RESOURCES[model_type]
    wer_temp = []
    cer_temp = []
    for ns in [-7.5, -10]:
        snr_name = "snr_" + str(ns)
        noisy_audio_ns_path = noisy_audio_path / snr_name / file_name
        noisy_audio_ns_path.mkdir(parents=True, exist_ok=True)
        noisy_audio_ns_filepath = noisy_audio_ns_path / "noisy_audio.wav"
        mixed = add_noise(audio, noise, ns)
        if not os.path.isfile(noisy_audio_ns_filepath):
            wavfile.write(noisy_audio_ns_filepath, sr, mixed)
        noisy_lip_filepath = noisy_audio_ns_path / "noisy_lip_movement.mp4"
        if not os.path.isfile(noisy_lip_filepath):
            FFmpeg(
                inputs={noisy_audio_ns_filepath: None, lip_video_filepath: None},
                outputs={noisy_lip_filepath: "-v quiet -c:v copy -c:a aac"},
            ).run()

        av_text = infer_av_hubert(
            resources["model"],
            resources["task"],
            resources["generator"],
            lip_video_filepath,
            noisy_audio_ns_filepath,
            duration=len(mixed) / 16000
        )
        av_text = _normalize_transcript(av_text)
        word_error_rate = wer(gt_text, av_text)
        character_error_rate = cer(gt_text, av_text)
        print(f"av_text : {av_text}")
        print(f"gt_text : {gt_text}")

        # NOTE: a rule that set is_valid[file_name] = 0 for hopeless
        # MultiTalk samples at SNR -7.5 is disabled, so `is_valid` is never
        # modified here.
        print(
            f"file_name : {file_name}, snr: {str(ns)}, word_error_rate : {word_error_rate}, character_error_rate : {character_error_rate}, is_valid[file_name] : {is_valid[file_name]}")
        wer_temp.append(word_error_rate)
        cer_temp.append(character_error_rate)
        # The noisy intermediates are scratch data; remove them right away.
        shutil.rmtree(noisy_audio_ns_path)
    return wer_temp, cer_temp, is_valid
if __name__ == "__main__":
    # Evaluation entry point: load resources, create pseudo ground-truth
    # transcripts with Whisper when they are missing, then score every video
    # of every requested model under each noise sample and SNR, and dump the
    # averaged WER/CER as JSON.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--avhubert-path", type=Path, required=False, default="./av_hubert/avhubert",
        help="Relative/Absolute path where avhubert repo is located."
    )
    # NOTE(review): the options below are required=True, so their `default`
    # values never take effect.
    parser.add_argument(
        "--work-dir", type=Path, required=True,
        default="/local_data_2/chaeyeon/interspeech2024/avlr",
        help="work directory for avlr evaluation"
    )
    parser.add_argument(
        "--language", type=str, required=True,
        default="English",
        help="evaluation language"
    )
    parser.add_argument(
        "--model-name", type=str, required=True,
        default="all",
        help="model name"
    )
    parser.add_argument(
        "--exp-name", type=str, required=True,
        default="base",
        help="experiment name"
    )
    args = parser.parse_args()

    logger.info("Loading noise samples..")
    start_time = time.time()
    work_path = args.work_dir / args.language
    input_path = work_path / "inputs"
    output_path = work_path / "outputs"
    lang_map = {'Arabic': 'ar', 'English': 'en', 'German': 'de', 'Italian': 'it', 'Portuguese': 'pt', 'Spanish': 'es',
                'French': 'fr', 'Greek': 'el', 'Russian': 'ru'}
    if args.language not in lang_map:
        raise ValueError(
            f"Unsupported language `{args.language}`; expected one of {sorted(lang_map)}"
        )
    checkpoint_path = work_path / "checkpoints"
    output_path.mkdir(parents=True, exist_ok=True)
    noise_path = args.work_dir / "noise_samples"
    NOISE = load_noise_samples(noise_path)

    logger.info("Loading AV models!")
    if not checkpoint_path.exists():
        raise ValueError(
            f"av-models-path: `{checkpoint_path}` doesn't exist!!"
        )
    utils.import_user_module(
        argparse.Namespace(user_dir=str(args.avhubert_path))
    )
    AV_RESOURCES = load_av_models(checkpoint_path)

    logger.info("Loading models responsible for preprocessing!")
    metadata_path = args.work_dir / "metadata"
    load_needed_models_for_lip_movement(metadata_path)
    logger.info("Done loading!")

    logger.info("Caching previously recorded videos!")
    for hash_path in output_path.rglob("*.md5"):
        with open(hash_path) as fin:
            md5hash = fin.read()
        VIDEOS_CACHE[md5hash] = hash_path.parent

    if args.model_name == "all":
        model_names = ["MultiTalk", "VOCA", "FaceFormer", "CodeTalker"]
    else:
        model_names = [args.model_name]
    wav_path = input_path / "wav"
    wer_path = work_path / f'wer_{args.exp_name}.json'
    cer_path = work_path / f'cer_{args.exp_name}.json'
    is_valid_path = work_path / f'is_valid_{args.exp_name}.json'
    # noise_types[i] is the category that noise_names[i] is looked up in.
    noise_types = ["indoors", "indoors", "music"]
    noise_names = ["dog-playing", "kids-playing", "leave_it_to_the_experts"]
    snr_values = ['-7.5', '-10']

    # Resolve every noise file once, before any heavy work. A name that is
    # not found must fail loudly: silently picking another file would file
    # the scores under the wrong noise name.
    noise_wav_by_name = {}
    for noise_type, noise_name in zip(noise_types, noise_names):
        matches = [
            candidate for candidate in NOISE[noise_type]
            if os.path.basename(candidate).split(".")[0] == noise_name
        ]
        if not matches:
            raise ValueError(
                f"Noise sample `{noise_name}` not found among `{noise_type}` noise files"
            )
        # Keep the last match when several files share the name.
        noise_wav_by_name[noise_name] = matches[-1]

    total_word_error_rate = {}
    total_character_error_rate = {}
    wer_results = {}
    cer_results = {}

    # ---- Pseudo ground truth -------------------------------------------
    # Transcripts are (re)generated for the first model's videos whenever the
    # number of text files differs from the number of videos.
    video_path = input_path / model_names[0]
    sorted_video_lists = sorted(glob.glob(os.path.join(video_path, "*.mp4")))
    text_path = input_path / "text"
    text_path.mkdir(parents=True, exist_ok=True)
    text_lists = glob.glob(os.path.join(text_path, "*.txt"))
    if len(text_lists) != len(sorted_video_lists):
        # Load Whisper only when transcription is actually required.
        model = WhisperModel("large-v3", device="cpu", compute_type="int8")
        punctuation_table = str.maketrans("", "", ".,!;:?")
        for vid in sorted_video_lists:
            file_name = os.path.basename(vid).split(".")[0]
            wav_file = os.path.join(wav_path, file_name + ".wav")
            segments, info = model.transcribe(audio=wav_file, language=lang_map[args.language],
                                              beam_size=5)
            text = "".join(segment.text for segment in segments)
            text_file = os.path.join(text_path, file_name + ".txt")
            with open(text_file, 'w') as f:
                f.write(text.translate(punctuation_table).strip())
    start_eval = time.time()
    print(f"Pseudo gt text made in {start_eval - start_time} secs.")

    # ---- Validity mask --------------------------------------------------
    # Some languages read a shared mask file instead of the per-experiment one.
    # NOTE(review): the results are written back to this same path below, so
    # the shared file is overwritten -- confirm that is intended.
    if args.language in ['Greek', 'Italian']:
        is_valid_path = work_path / 'is_valid_base_wo_self.json'
    elif args.language in ['English', 'French', 'German']:
        is_valid_path = work_path / 'is_valid_base.json'
    with open(is_valid_path, 'r') as f:
        is_valid = json.load(f)
    # Disabled alternative: mark every video valid instead of loading a file.
    # for vid in sorted_video_lists:
    #     is_valid[vid.split("/")[-1].split(".")[0]] = 1

    # For these languages the first occurrence of the given capital letter in
    # the file name is lower-cased (presumably to match the keys of
    # `is_valid` and the transcript file names -- confirm against the data).
    first_letter_fix = {
        "French": ("F", "f"),
        "English": ("E", "e"),
        "Italian": ("I", "i"),
        "Greek": ("G", "g"),
    }
    av_model_type = sorted(AV_RESOURCES.keys())[0]
    result_files = (
        (wer_path, wer_results),
        (cer_path, cer_results),
    )

    # ---- Evaluation -----------------------------------------------------
    for model_name in model_names:
        total_word_error_rate[model_name] = {
            noise_name: {"-7.5": 0.0, "-10": 0.0} for noise_name in noise_names
        }
        total_character_error_rate[model_name] = {
            noise_name: {"-7.5": 0.0, "-10": 0.0} for noise_name in noise_names
        }
        video_path = input_path / model_name
        sorted_video_lists = sorted(glob.glob(os.path.join(video_path, "*.mp4")))
        for vid in sorted_video_lists:
            file_name = os.path.basename(vid).split(".")[0]
            if args.language in first_letter_fix:
                upper, lower = first_letter_fix[args.language]
                file_name = file_name.replace(upper, lower, 1)
            if is_valid[file_name] == 0:
                continue
            text_file = os.path.join(text_path, file_name + ".txt")
            with open(text_file, "r") as f:
                gt_text = f.readlines()[0]

            for noise_type, noise_name in zip(noise_types, noise_names):
                if is_valid[file_name] == 0:
                    continue
                word_error_rate, character_error_rate, is_valid = test_WER(
                    av_model_type, vid, gt_text,
                    noise_type, model_name, noise_name,
                    noise_wav_by_name[noise_name], output_path, file_name, is_valid)
                if is_valid[file_name] == 1:
                    for snr_idx, snr_value in enumerate(snr_values):
                        total_word_error_rate[model_name][noise_name][snr_value] += word_error_rate[snr_idx]
                        total_character_error_rate[model_name][noise_name][snr_value] += character_error_rate[snr_idx]

            # Per-video intermediates are no longer needed once every noise
            # condition has been scored.
            out_filepath = output_path / model_name / file_name
            for leftover in ("audio.wav", "video.mp4", "lip_movement.mp4"):
                os.remove(out_filepath / leftover)

        num_valid = sum(is_valid.values())
        print(f"sum(is_valid.values) : {num_valid}, len(is_valid) : {len(is_valid)}")
        wer_results[model_name] = {"-7.5": 0, "-10": 0}
        cer_results[model_name] = {"-7.5": 0, "-10": 0}
        # Average over valid files first, then over the noise samples.
        for snr_value in snr_values:
            for noise_name in noise_names:
                wer_results[model_name][snr_value] += total_word_error_rate[model_name][noise_name][snr_value] / num_valid
                cer_results[model_name][snr_value] += total_character_error_rate[model_name][noise_name][snr_value] / num_valid
            wer_results[model_name][snr_value] = wer_results[model_name][snr_value] / len(noise_names)
            cer_results[model_name][snr_value] = cer_results[model_name][snr_value] / len(noise_names)

        # Checkpoint after every model so a crash does not lose earlier results.
        for path, payload in result_files + ((is_valid_path, is_valid),):
            with open(path, 'w') as f:
                json.dump(payload, f, indent=4)
        print(f"{model_name} end.")

    for path, payload in result_files + ((is_valid_path, is_valid),):
        with open(path, 'w') as f:
            json.dump(payload, f, indent=4)
    print(f"Total end in {time.time()-start_eval} secs.")