Spaces:
Runtime error
Runtime error
import threading | |
from pathlib import Path | |
from nota_wav2lip.demo import Wav2LipModelComparisonDemo | |
class Wav2LipModelComparisonGradio(Wav2LipModelComparisonDemo):
    """Gradio-facing wrapper around ``Wav2LipModelComparisonDemo``.

    Provides event-handler methods that map UI sample labels to their files
    and run inference through the inherited ``save_as_video`` with either
    ``model_type='wav2lip'`` (original) or ``model_type='nota_wav2lip'``
    (compressed). Handlers never raise ordinary exceptions to the caller:
    failures are logged with a traceback and empty results are returned.
    """

    def __init__(
        self,
        device='cpu',
        result_dir='./temp',
        video_label_dict=None,
        audio_label_list=None,
        default_video='v1',
        default_audio='a1'
    ) -> None:
        """Create the wrapper.

        Args:
            device: Passed through to the parent constructor.
            result_dir: Passed through to the parent constructor.
            video_label_dict: Mapping of video label -> path. Each path is
                stored as a ``Path`` with its suffix replaced by ``.mp4``.
            audio_label_list: Despite its name, used as a mapping of audio
                label -> value (it is queried with ``in`` and ``[]``). It is
                stored as-is, not copied.
            default_video: Label used by ``switch_video_samples`` when the
                selection is unknown.
            default_audio: Label used by ``switch_audio_samples`` when the
                selection is unknown.
        """
        if audio_label_list is None:
            audio_label_list = {}
        if video_label_dict is None:
            video_label_dict = {}
        super().__init__(device, result_dir)
        self._video_label_dict = {k: Path(v).with_suffix('.mp4') for k, v in video_label_dict.items()}
        self._audio_label_dict = audio_label_list
        self._default_video = default_video
        self._default_audio = default_audio
        # Serializes inference so only one request runs at a time
        # (original intent: "asserting that concurrency_count == 1").
        self._lock = threading.Lock()

    @staticmethod
    def _log_exception(message, *args):
        """Log ``message`` (lazy %-args) plus the active exception's traceback.

        Must be called from inside an ``except`` block.
        """
        import logging  # local import keeps this class self-contained
        logging.getLogger(__name__).exception(message, *args)

    def _is_valid_input(self, video_selection, audio_selection):
        """Raise ``ValueError`` if either selection is not a known label.

        Explicit raises are used instead of ``assert`` so the check is not
        stripped when Python runs with ``-O``.
        """
        if video_selection not in self._video_label_dict:
            raise ValueError(
                f"Your input ({video_selection}) is not in {self._video_label_dict}!!!")
        if audio_selection not in self._audio_label_dict:
            raise ValueError(
                f"Your input ({audio_selection}) is not in {self._audio_label_dict}!!!")

    def _generate(self, video_selection, audio_selection, model_type):
        """Run ``save_as_video`` with ``model_type`` and format the result.

        Returns:
            ``(output_video_path, inference_time, inference_fps)`` as strings
            (time with 2 decimals, fps with 1 decimal). On any failure the
            error is logged and ``(None, None, None)`` is returned, so a
            three-output UI handler still receives one value per output.
        """
        try:
            self._is_valid_input(video_selection, audio_selection)
            with self._lock:
                output_video_path, inference_time, inference_fps = self.save_as_video(
                    audio_name=audio_selection,
                    video_name=video_selection,
                    model_type=model_type,
                )
            return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f")
        except KeyboardInterrupt:
            # Same effect as the builtin ``exit()``, without relying on the
            # ``site`` module having installed it.
            raise SystemExit
        except Exception:
            self._log_exception(
                "Inference failed (model_type=%r, video=%r, audio=%r)",
                model_type, video_selection, audio_selection)
            return None, None, None

    def generate_original_model(self, video_selection, audio_selection):
        """Run the original model (``model_type='wav2lip'``); see ``_generate``."""
        return self._generate(video_selection, audio_selection, 'wav2lip')

    def generate_compressed_model(self, video_selection, audio_selection):
        """Run the compressed model (``model_type='nota_wav2lip'``); see ``_generate``."""
        return self._generate(video_selection, audio_selection, 'nota_wav2lip')

    def _switch_sample(self, label_dict, selection, default):
        """Return ``label_dict[selection]``, falling back to ``label_dict[default]``.

        Returns ``None`` (after logging) if the lookup fails, e.g. when the
        default label itself is missing.
        """
        try:
            key = selection if selection in label_dict else default
            return label_dict[key]
        except KeyboardInterrupt:
            raise SystemExit
        except Exception:
            self._log_exception(
                "Sample lookup failed (selection=%r, default=%r)", selection, default)
            return None

    def switch_video_samples(self, video_selection):
        """Return the ``.mp4`` path for ``video_selection`` (default if unknown)."""
        return self._switch_sample(self._video_label_dict, video_selection, self._default_video)

    def switch_audio_samples(self, audio_selection):
        """Return the stored audio entry for ``audio_selection`` (default if unknown)."""
        return self._switch_sample(self._audio_label_dict, audio_selection, self._default_audio)