import os
import sys
import time
import glob
import json
import random
import shutil
import hashlib
import logging
import argparse

import cv2
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from pathlib import Path
from ffmpy import FFmpeg
from scipy.io import wavfile
from jiwer import wer, cer
from faster_whisper import WhisperModel

# Fix all random seeds for reproducibility.
random_seed = 1234
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

sys.path.insert(0, str(Path(__file__).parent.parent))
from demo_utils import *
from utils import (
    split_video_to_frames,
    resize_frames,
    crop_patch,
    save_video,
)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def detect_landmark(image):
    """Detect faces with dlib and return the 68 facial landmarks of the last
    detected face as a (68, 2) int array, or None if no face was found."""
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    rects = DETECTOR(gray, 1)
    coords = None
    for (_, rect) in enumerate(rects):
        shape = PREDICTOR(gray, rect)
        coords = np.zeros((68, 2), dtype=np.int32)
        for i in range(0, 68):
            coords[i] = (shape.part(i).x, shape.part(i).y)
    return coords
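
# DETECTOR, PREDICTOR, and MEAN_FACE_LANDMARKS are loaded in __main__ via
# load_needed_models_for_lip_movement(). A minimal sketch of what that setup
# typically looks like with dlib (the model filename below is an assumption):
#
#   import dlib
#   DETECTOR = dlib.get_frontal_face_detector()
#   PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")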


@track_time
def extract_lip_movement(
    webcam_video,
    in_video_filepath,
    out_lip_filepath,
    num_workers,
):
    """Convert the raw recording to 25 fps, detect landmarks per frame, and
    save the cropped mouth region as a video at `out_lip_filepath`."""

    def copy_video_if_ready(webcam_video, out_path):
        # Re-use previously processed results: videos are keyed by the MD5
        # hash of their raw bytes in the module-level VIDEOS_CACHE.
        with open(webcam_video, 'rb') as fin:
            curr_md5hash = hashlib.md5(fin.read()).hexdigest()

        if curr_md5hash in VIDEOS_CACHE:
            dst_path = VIDEOS_CACHE[curr_md5hash]
            shutil.copy(dst_path / "video.mp4", out_path)
            shutil.copy(dst_path / "lip_movement.mp4", out_path)
            shutil.copy(dst_path / "raw_video.md5", out_path)
            return True
        else:
            VIDEOS_CACHE[curr_md5hash] = out_path
            with open(out_path / "raw_video.md5", 'w') as fout:
                fout.write(curr_md5hash)
            return False

    # Cache lookup is disabled for this evaluation run:
    # if copy_video_if_ready(webcam_video, in_video_filepath.parent):
    #     logger.info("Skip video processing; loading the cached one!")
    #     return

    logger.info("Adjusting video framerate to 25 fps")
    if not os.path.isfile(in_video_filepath):
        # NOTE: the original passed both `-filter:v` and `-vf`; ffmpeg applies
        # only the last one, so the two filters are combined into one chain.
        FFmpeg(
            inputs={webcam_video: None},
            outputs={in_video_filepath: "-v quiet -vf fps=25,scale=640:480"},
        ).run()
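
    # For reference, the equivalent shell command would be roughly:
    #   ffmpeg -i <webcam_video> -v quiet -vf fps=25,scale=640:480 <in_video_filepath>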

    logger.info("Converting video into frames")
    frames = list(split_video_to_frames(in_video_filepath))

    logger.info("Extracting face landmarks from video frames")
    landmarks = [
        detect_landmark(frame)
        for frame in tqdm(frames, desc="Detecting Lip Movement")
    ]

    invalid_landmarks_ratio = sum(lnd is None for lnd in landmarks) / len(landmarks)
    logger.info(f"Current invalid frame ratio: {invalid_landmarks_ratio}")
    if invalid_landmarks_ratio > MAX_MISSING_FRAMES_RATIO:
        logger.info(
            "Invalid frame ratio exceeded the maximum allowed ratio! "
            "Resizing the recorded video instead of cropping."
        )
        sequence = resize_frames(frames)
    else:
        if invalid_landmarks_ratio != 0:
            logger.info("Linearly interpolating invalid landmarks")
            continuous_landmarks = landmarks_interpolate(landmarks)
        else:
            continuous_landmarks = landmarks

        logger.info("Cropping the mouth region.")
        sequence = crop_patch(
            frames,
            len(frames),
            continuous_landmarks,
            MEAN_FACE_LANDMARKS,
        )

    save_video(sequence, out_lip_filepath, fps=25)


def process_input_video(
    model_type: str,
    input_video_path: str,
    noise_snr: int,
    noise_type: str,
    outpath: str,
):
    """Full demo pipeline: extract the lip video, mix the audio with noise at
    the given SNR, mux them together, and run AV-HuBERT inference."""
    if input_video_path is None:
        raise IOError(
            "Gradio didn't record the video. Please refresh the web page!"
        )

    audio_filepath = outpath / "audio.wav"
    video_filepath = outpath / "video.mp4"
    noisy_audio_filepath = outpath / "noisy_audio.wav"
    lip_video_filepath = outpath / "lip_movement.mp4"

    if not os.path.isfile(video_filepath) and not os.path.isfile(lip_video_filepath):
        extract_lip_movement(
            input_video_path, video_filepath, lip_video_filepath,
            num_workers=min(os.cpu_count(), 5)
        )

    logger.info(f"Mixing audio with `{noise_type}` noise (SNR={noise_snr}).")
    noise_wav_files = NOISE[noise_type]
    noise_wav_file = noise_wav_files[random.randint(0, len(noise_wav_files) - 1)]
    logger.debug(f"Noise wav used is {noise_wav_file}")
    mixed = mix_audio_with_noise(
        input_video_path, audio_filepath, noisy_audio_filepath,
        noise_wav_file, noise_snr
    )

    logger.info("Muxing the noisy audio with the lip-movement video.")
    noisy_lip_filepath = outpath / "noisy_lip_movement.mp4"
    FFmpeg(
        inputs={noisy_audio_filepath: None, lip_video_filepath: None},
        outputs={noisy_lip_filepath: "-v quiet -c:v copy -c:a aac"},
    ).run()
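
    # Equivalent shell command (for reference, roughly):
    #   ffmpeg -i noisy_audio.wav -i lip_movement.mp4 -v quiet -c:v copy -c:a aac noisy_lip_movement.mp4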

    av_text = infer_av_hubert(
        AV_RESOURCES[model_type]["model"],
        AV_RESOURCES[model_type]["task"],
        AV_RESOURCES[model_type]["generator"],
        lip_video_filepath,
        noisy_audio_filepath,
        duration=len(mixed) / 16000,
    )
    logger.info(f"AV-HuBERT output: {av_text}")

    logger.info("Summary:")
    for k, v in TIME_TRACKER.items():
        logger.info(f"Function {k} executed in {v} seconds")
    logger.info(30 * '=' + " Done! " + '=' * 30)
    return (str(noisy_lip_filepath), av_text)


def test_WER(
    model_type: str,
    input_video_path: str,
    gt_text: str,
    noise_type: str,
    model_name: str,
    noise_name: str,
    noise_wav_file: str,
    outpath: str,
    file_name: str,
    is_valid: dict,
):
    """Score AV-HuBERT on one video mixed with one noise file at each SNR in
    {-7.5, -10} dB; returns per-SNR WER/CER lists and the updated `is_valid`."""
    if input_video_path is None:
        raise IOError(
            "Gradio didn't record the video. Please refresh the web page!"
        )
    out_filepath = outpath / model_name / file_name
    out_filepath.mkdir(parents=True, exist_ok=True)
    audio_filepath = out_filepath / "audio.wav"
    video_filepath = out_filepath / "video.mp4"
    noisy_audio_path = outpath / model_name / noise_type / noise_name
    noisy_audio_path.mkdir(parents=True, exist_ok=True)

    lip_video_filepath = out_filepath / "lip_movement.mp4"
    if not os.path.isfile(lip_video_filepath):
        extract_lip_movement(
            input_video_path, video_filepath, lip_video_filepath,
            num_workers=min(os.cpu_count(), 5)
        )

    if not os.path.isfile(audio_filepath):
        FFmpeg(
            inputs={input_video_path: None},
            outputs={audio_filepath: "-v quiet -vn -acodec pcm_s16le -ar 16000 -ac 1"},
        ).run()

    sr, audio = wavfile.read(audio_filepath)
    _, noise = wavfile.read(noise_wav_file)

    wer_temp = []
    cer_temp = []
    # Clean-audio (no-noise) scoring is disabled for this evaluation run:
    # origin_av_text = infer_av_hubert(
    #     AV_RESOURCES[model_type]["model"],
    #     AV_RESOURCES[model_type]["task"],
    #     AV_RESOURCES[model_type]["generator"],
    #     lip_video_filepath,
    #     audio_filepath,
    #     duration=len(audio) / 16000,
    # )
    # word_error_rate = wer(gt_text.lower().replace('\n', ''), origin_av_text.lower().replace('\n', ''))
    # character_error_rate = cer(gt_text.lower().replace('\n', ''), origin_av_text.lower().replace('\n', ''))
    # wer_temp.append(word_error_rate)
    # cer_temp.append(character_error_rate)

    for ns in [-7.5, -10]:
        snr_name = "snr_" + str(ns)
        noisy_audio_ns_path = noisy_audio_path / snr_name / file_name
        noisy_audio_ns_path.mkdir(parents=True, exist_ok=True)
        noisy_audio_ns_filepath = noisy_audio_ns_path / "noisy_audio.wav"
        mixed = add_noise(audio, noise, ns)
        if not os.path.isfile(noisy_audio_ns_filepath):
            wavfile.write(noisy_audio_ns_filepath, sr, mixed)
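
        # add_noise comes from demo_utils; it is expected to scale the noise so
        # that 10*log10(P_speech / P_noise) equals the requested SNR. A minimal
        # sketch of that behavior (an assumption, not the actual implementation):
        #
        #   def add_noise(speech, noise, snr_db):
        #       noise = np.resize(noise, speech.shape).astype(np.float64)
        #       scale = np.sqrt(np.sum(speech.astype(np.float64) ** 2)
        #                       / (np.sum(noise ** 2) * 10 ** (snr_db / 10)))
        #       return (speech + scale * noise).astype(speech.dtype)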

        noisy_lip_filepath = noisy_audio_ns_path / "noisy_lip_movement.mp4"
        if not os.path.isfile(noisy_lip_filepath):
            FFmpeg(
                inputs={noisy_audio_ns_filepath: None, lip_video_filepath: None},
                outputs={noisy_lip_filepath: "-v quiet -c:v copy -c:a aac"},
            ).run()

        av_text = infer_av_hubert(
            AV_RESOURCES[model_type]["model"],
            AV_RESOURCES[model_type]["task"],
            AV_RESOURCES[model_type]["generator"],
            lip_video_filepath,
            noisy_audio_ns_filepath,
            duration=len(mixed) / 16000,
        )
        # Normalize hypothesis and reference identically before scoring:
        # strip punctuation, lowercase, and drop newlines.
        for ch in '.,!;:?/':
            av_text = av_text.replace(ch, '')
            gt_text = gt_text.replace(ch, '')
        av_text = av_text.lower().replace('\n', '').strip()
        gt_text = gt_text.lower().replace('\n', '').strip()
        word_error_rate = wer(gt_text, av_text)
        character_error_rate = cer(gt_text, av_text)
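
        # Note: jiwer's wer() is (S + D + I) / N over whitespace tokens, so a
        # single utterance can score above 1.0 when many insertions occur.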
        print(f"av_text : {av_text}")
        print(f"gt_text : {gt_text}")
        # Invalidation of high-WER files is disabled for this run:
        # if sum(is_valid.values()) >= 51 and word_error_rate >= 1.0 and ns == -7.5 and (model_name in ["MultiTalk"]):
        #     is_valid[file_name] = 0
        print(
            f"file_name : {file_name}, snr: {ns}, "
            f"word_error_rate : {word_error_rate}, "
            f"character_error_rate : {character_error_rate}, "
            f"is_valid[file_name] : {is_valid[file_name]}"
        )
        wer_temp.append(word_error_rate)
        cer_temp.append(character_error_rate)
        # Remove the per-SNR noisy artifacts to save disk space.
        shutil.rmtree(noisy_audio_ns_path)
    return wer_temp, cer_temp, is_valid


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--avhubert-path", type=Path, required=False, default="./av_hubert/avhubert",
        help="Relative/absolute path where the avhubert repo is located."
    )
    parser.add_argument(
        "--work-dir", type=Path, required=True,
        default="/local_data_2/chaeyeon/interspeech2024/avlr",
        help="Work directory for AVLR evaluation."
    )
    parser.add_argument(
        "--language", type=str, required=True,
        default="English",
        help="Evaluation language."
    )
    parser.add_argument(
        "--model-name", type=str, required=True,
        default="all",
        help="Model name."
    )
    parser.add_argument(
        "--exp-name", type=str, required=True,
        default="base",
        help="Experiment name."
    )

    args = parser.parse_args()
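
    # Example invocation (script name and paths are placeholders):
    #   python evaluate_wer.py --work-dir /data/avlr --language English \
    #       --model-name all --exp-name base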

    logger.info("Loading noise samples..")
    start_time = time.time()
    work_path = args.work_dir / args.language
    input_path = work_path / "inputs"
    output_path = work_path / "outputs"

    lang_map = {
        'Arabic': 'ar', 'English': 'en', 'German': 'de', 'Italian': 'it',
        'Portuguese': 'pt', 'Spanish': 'es', 'French': 'fr', 'Greek': 'el',
        'Russian': 'ru',
    }
    checkpoint_path = work_path / "checkpoints"

    av_model_path = os.path.join(checkpoint_path, lang_map[args.language] + "_avsr")

    output_path.mkdir(parents=True, exist_ok=True)

    noise_path = args.work_dir / "noise_samples"
    NOISE = load_noise_samples(noise_path)
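
    # NOISE maps each noise type (e.g. "indoors", "music") to a list of wav
    # file paths; test_WER below picks individual files from these lists.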

    logger.info("Loading AV models!")
    if not checkpoint_path.exists():
        raise ValueError(
            f"av-models-path: `{checkpoint_path}` doesn't exist!"
        )
    # `utils` here is expected to be fairseq's utils module, brought into
    # scope by the wildcard import from demo_utils (the explicit
    # `from utils import ...` above only binds individual names).
    utils.import_user_module(
        argparse.Namespace(user_dir=str(args.avhubert_path))
    )
    AV_RESOURCES = load_av_models(checkpoint_path)

    logger.info("Loading models responsible for preprocessing!")
    metadata_path = args.work_dir / "metadata"
    DETECTOR, PREDICTOR, MEAN_FACE_LANDMARKS = (
        load_needed_models_for_lip_movement(metadata_path)
    )
    logger.info("Done loading!")

    VIDEOS_CACHE = {}
    logger.info("Caching previously recorded videos!")
    # Map each raw video's MD5 hash to the directory holding its processed outputs.
    for hash_path in output_path.rglob("*.md5"):
        with open(hash_path) as fin:
            md5hash = fin.read()
        VIDEOS_CACHE[md5hash] = hash_path.parent

    if args.model_name == "all":
        model_names = ["MultiTalk", "VOCA", "FaceFormer", "CodeTalker"]
    else:
        model_names = [args.model_name]

    wav_path = input_path / "wav"

    model = WhisperModel("large-v3", device="cpu", compute_type="int8")
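
    # faster-whisper large-v3 on CPU with int8 quantization; it is only used
    # to produce pseudo ground-truth transcripts when reference texts are
    # missing below.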

    wer_path = work_path / f'wer_{args.exp_name}.json'
    cer_path = work_path / f'cer_{args.exp_name}.json'
    is_valid_path = work_path / f'is_valid_{args.exp_name}.json'
    # Two "indoors" noises and one "music" noise, matched by name below.
    noise_types = ["indoors", "indoors", "music"]
    noise_names = ["dog-playing", "kids-playing", "leave_it_to_the_experts"]
    snr_values = ['-7.5', '-10']
    total_word_error_rate = {}
    total_character_error_rate = {}
    wer_results = {}
    cer_results = {}

    video_path = input_path / model_names[0]
    video_lists = glob.glob(os.path.join(video_path, "*.mp4"))
    sorted_video_lists = sorted(video_lists)

    text_path = input_path / "text"
    text_path.mkdir(parents=True, exist_ok=True)
    text_lists = glob.glob(os.path.join(text_path, "*.txt"))
    # If any transcript is missing, regenerate pseudo ground-truth texts with Whisper.
    if len(text_lists) != len(video_lists):
        for vid in sorted_video_lists:
            file_name = vid.split("/")[-1].split(".")[0]

            wav_file = os.path.join(wav_path, file_name + ".wav")
            segments, info = model.transcribe(
                audio=wav_file, language=lang_map[args.language], beam_size=5
            )
            text = ''
            for segment in segments:
                text = text + segment.text

            text_file = os.path.join(text_path, file_name + ".txt")
            for ch in '.,!;:?':
                text = text.replace(ch, '')
            with open(text_file, 'w') as f:
                f.write(text.strip())

    start_eval = time.time()
    print(f"Pseudo GT text made in {start_eval - start_time} secs.")

    if args.language in ['Greek', 'Italian']:
        is_valid_path = work_path / 'is_valid_base_wo_self.json'
    elif args.language in ['English', 'French', 'German']:
        is_valid_path = work_path / 'is_valid_base.json'
    with open(is_valid_path, 'r') as f:
        is_valid = json.load(f)
    # Alternatively, initialize every file as valid:
    # for vid in sorted_video_lists:
    #     file_name = vid.split("/")[-1].split(".")[0]
    #     is_valid[file_name] = 1

    for model_name in model_names:
        total_word_error_rate[model_name] = {}
        total_character_error_rate[model_name] = {}
        for noise_name in noise_names:
            total_word_error_rate[model_name][noise_name] = {"-7.5": 0.0, "-10": 0.0}
            total_character_error_rate[model_name][noise_name] = {"-7.5": 0.0, "-10": 0.0}
        video_path = input_path / model_name
        video_lists = glob.glob(os.path.join(video_path, "*.mp4"))
        sorted_video_lists = sorted(video_lists)
        for vid in sorted_video_lists:
            file_name = vid.split("/")[-1].split(".")[0]
            # Lowercase the leading language letter of the file name.
            if args.language == "French":
                file_name = file_name.replace('F', 'f', 1)
            elif args.language == "English":
                file_name = file_name.replace('E', 'e', 1)
            elif args.language == "Italian":
                file_name = file_name.replace('I', 'i', 1)
            elif args.language == "Greek":
                file_name = file_name.replace('G', 'g', 1)
            if is_valid[file_name] == 0:
                continue
            text_file = os.path.join(text_path, file_name + ".txt")
            with open(text_file, "r") as f:
                gt_text = f.readlines()[0]
            for idx, noise_name in enumerate(noise_names):
                if is_valid[file_name] == 0:
                    continue
                noise_type = noise_types[idx]
                noise_wav_files = NOISE[noise_type]
                # Find the noise wav whose file stem matches this noise name.
                noise_index = -1
                for noise_idx, noise_wav_file in enumerate(noise_wav_files):
                    noise_temp_name = noise_wav_file.split("/")[-1].split(".")[0]
                    if noise_name == noise_temp_name:
                        noise_index = noise_idx

                noise_wav_file = noise_wav_files[noise_index]
                word_error_rate, character_error_rate, is_valid = test_WER(
                    sorted(AV_RESOURCES.keys())[0], vid, gt_text,
                    noise_type, model_name, noise_name,
                    noise_wav_file, output_path, file_name, is_valid,
                )
                if is_valid[file_name] == 1:
                    for snr_idx in range(len(snr_values)):
                        total_word_error_rate[model_name][noise_name][snr_values[snr_idx]] += word_error_rate[snr_idx]
                        total_character_error_rate[model_name][noise_name][snr_values[snr_idx]] += character_error_rate[snr_idx]
            # Clean up per-file intermediates once this video has been scored.
            out_filepath = output_path / model_name / file_name
            audio_filepath = out_filepath / "audio.wav"
            video_filepath = out_filepath / "video.mp4"
            lip_video_filepath = out_filepath / "lip_movement.mp4"
            os.remove(audio_filepath)
            os.remove(video_filepath)
            os.remove(lip_video_filepath)
            print(f"sum(is_valid.values()) : {sum(is_valid.values())}, len(is_valid) : {len(is_valid)}")

        wer_results[model_name] = {"-7.5": 0, "-10": 0}
        cer_results[model_name] = {"-7.5": 0, "-10": 0}

        # Average each SNR's accumulated score over the valid files, then over
        # the three noise conditions.
        for snr_value in snr_values:
            for noise_name in noise_names:
                wer_results[model_name][snr_value] += total_word_error_rate[model_name][noise_name][snr_value] / sum(is_valid.values())
                cer_results[model_name][snr_value] += total_character_error_rate[model_name][noise_name][snr_value] / sum(is_valid.values())
            wer_results[model_name][snr_value] = wer_results[model_name][snr_value] / len(noise_names)
            cer_results[model_name][snr_value] = cer_results[model_name][snr_value] / len(noise_names)
        # Checkpoint the results after every model so a crash loses nothing.
        with open(wer_path, 'w') as f:
            json.dump(wer_results, f, indent=4)
        with open(cer_path, 'w') as f:
            json.dump(cer_results, f, indent=4)
        with open(is_valid_path, 'w') as f:
            json.dump(is_valid, f, indent=4)
        print(f"{model_name} end.")

    with open(wer_path, 'w') as f:
        json.dump(wer_results, f, indent=4)
    with open(cer_path, 'w') as f:
        json.dump(cer_results, f, indent=4)
    with open(is_valid_path, 'w') as f:
        json.dump(is_valid, f, indent=4)

    print(f"Total end in {time.time() - start_eval} secs.")