TDN-M's picture
Update app.py
74076e9 verified
raw
history blame
4.76 kB
import streamlit as st
import os
import sys
import torch
import pickle
import numpy
import librosa
import subprocess
from avatar import Avatar
def run_pickleface():
    """Run ``pickleface.py`` in a child process and report whether it succeeded.

    The script is started with the interpreter that is running this app
    (``sys.executable``) so that it uses the same environment; a bare
    ``'python'`` may resolve to a different interpreter or may not be on
    PATH at all.

    Returns:
        bool: True when the script exits with status 0 (its captured stdout
        is printed). False when it exits with a non-zero status or the
        process cannot be started (the error is printed).
    """
    # sys.executable may be empty/None when Python cannot determine its own
    # path; fall back to the previous behaviour in that case.
    interpreter = sys.executable or 'python'
    try:
        result = subprocess.run(
            [interpreter, 'pickleface.py'],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error running pickleface.py: {e.stderr}")
        return False
    except OSError as e:
        # The interpreter itself could not be launched (e.g. not found).
        print(f"Error running pickleface.py: {e}")
        return False
    print(result.stdout)
    return True
def initialize_face_detection_results():
    """Make sure a face-detection pickle exists for every avatar option.

    For each entry of the module-level ``options`` list the file
    ``ref_videos/<option>_face_det_result.pkl`` is looked up. When at least
    one is missing, ``run_pickleface()`` is called, with progress written to
    ``current_status_placeholder``. If that call reports failure, an error is
    shown with ``st.error`` and the script run is halted with ``st.stop()``.
    """
    # Check whether all of the pkl files already exist
    all_present = all(
        os.path.exists(f'ref_videos/{opt}_face_det_result.pkl') for opt in options
    )
    if all_present:
        return

    current_status_placeholder.write("Creating face detection results...")
    generated = run_pickleface()
    if not generated:
        st.error("Failed to create face detection results")
        st.stop()
    current_status_placeholder.write("Face detection results created successfully!")
# Initial configuration: the selectable reference avatars and their preview
# images. images[i] is the preview for options[i], so the image list is
# derived from the option names to keep the two lists aligned.
options = ['Aude', 'Kyla', 'Liv', 'MC6']
images = [f'ref_videos/{name}.png' for name in options]

# Add the path to the Wav2Lip directory (next to this file) to sys.path
wav2lip_path = os.path.join(os.path.dirname(__file__), "Wav2Lip")
if wav2lip_path not in sys.path:
    sys.path.insert(0, wav2lip_path)

# User interface: centered page title.
# The font size uses the CSS unit `px`; the previous value `30x` is not a
# valid CSS length.
big_text = """
<div style='text-align: center;'>
<h1 style='font-size: 30px;'>Text to Speech Synchronized Video</h1>
</div>
"""
st.markdown(big_text, unsafe_allow_html=True)

# Placeholders that the initialisation code updates with status text and
# progress.
current_status_placeholder = st.empty()
init_progress_bar = st.progress(0)
# Initialise the session state. This branch runs only while the
# 'is_initialized' key is absent from st.session_state; the key is stored as
# the very last step.
if 'is_initialized' not in st.session_state:
    initialize_face_detection_results()

    # Create the Avatar. `session_avatar` is a local alias for the object
    # stored in the session state.
    session_avatar = Avatar()
    st.session_state.avatar = session_avatar
    session_avatar.export_video = False

    # Load the model
    current_status_placeholder.write("Loading model...")
    session_avatar.load_model("checkpoint/wav2lip_gan.pth")
    current_status_placeholder.write("Model loaded successfully")

    # Device configuration: use CUDA when available, otherwise the CPU
    if torch.cuda.is_available():
        session_avatar.device = 'cuda'
    else:
        session_avatar.device = 'cpu'
    print(f"Using device: {session_avatar.device}")

    # Path configuration
    session_avatar.output_audio_path = "audio/"
    session_avatar.output_audio_filename = "result.wav"
    session_avatar.temp_lip_video_no_voice_path = "temp/"
    session_avatar.temp_lip_video_no_voice_filename = "result.avi"
    session_avatar.output_video_path = "results/"
    session_avatar.output_video_name = "result_voice.mp4"

    # Default reference video
    st.session_state.selected_option = "Liv"
    default_option = st.session_state.selected_option
    session_avatar.ref_video_path_and_filename = f"ref_videos/{default_option}.mp4"

    # Video processing and face detection settings
    session_avatar.get_video_full_frames(session_avatar.ref_video_path_and_filename)
    session_avatar.face_detect_batch_size = 16

    # Load the face detection results for all options.
    # NOTE: pickle.load executes whatever the file encodes; these .pkl files
    # must come from a trusted source.
    st.session_state.face_det_results_dict = {}
    for option in options:
        pkl_path = f'ref_videos/{option}_face_det_result.pkl'
        with open(pkl_path, 'rb') as pkl_file:
            st.session_state.face_det_results_dict[option] = pickle.load(pkl_file)

    session_avatar.face_detect_img_results = st.session_state.face_det_results_dict[default_option]
    session_avatar.face_det_results_path_and_name = f'ref_videos/{default_option}_face_det_result.pkl'

    # Text-to-speech processing with a fixed sample sentence
    input_text = "Hi How are you?"
    session_avatar.text_to_lip_video(input_text, init_progress_bar)

    current_status_placeholder.write("Face detection results loaded")
    st.session_state['is_initialized'] = True
# Video selection UI: a radio button pre-selected to the option stored in the
# session state, with the matching preview image in the first of two columns.
current_index = options.index(st.session_state.selected_option)
selected_option = st.radio("Choose an option:", options, index=current_index)
img_col1, img_col2 = st.columns([1,1])
with img_col1:
    preview_image = images[options.index(selected_option)]
    st.image(preview_image)

# Handle a change of the selected video: remember the new option and point
# the avatar at the corresponding reference video and face detection results.
if selected_option != st.session_state.selected_option:
    print("The selected option has changed!")
    st.session_state.selected_option = selected_option
    session_avatar = st.session_state.avatar
    session_avatar.ref_video_path_and_filename = f"ref_videos/{selected_option}.mp4"
    session_avatar.get_video_full_frames(session_avatar.ref_video_path_and_filename)
    session_avatar.face_detect_img_results = st.session_state.face_det_results_dict[selected_option]