|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
import os
|
|
import tqdm
|
|
from statistics import fmean
|
|
from eval.syncnet import SyncNetEval
|
|
from eval.syncnet_detect import SyncNetDetector
|
|
from latentsync.utils.util import red_text
|
|
import torch
|
|
|
|
|
|
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
    """Score one video with SyncNet and report the averaged result.

    The detector is invoked on ``video_path`` (with ``min_track=50``), then
    every entry found in ``<detect_results_dir>/crop`` is passed to
    ``syncnet.evaluate``. The per-entry AV offsets and confidences are
    averaged, printed, and returned.

    NOTE(review): this presumes the detector writes its cropped clips into
    ``<detect_results_dir>/crop`` -- confirm against ``SyncNetDetector``.

    Args:
        syncnet: Object exposing ``evaluate(video_path=..., temp_dir=...)``
            that returns a 3-tuple whose first item is the AV offset and
            whose last item is the confidence.
        syncnet_detector: Callable accepting ``video_path`` and ``min_track``.
        video_path: Path of the video to evaluate.
        temp_dir: Scratch directory forwarded to ``syncnet.evaluate``.
        detect_results_dir: Directory whose ``crop`` sub-directory is listed.

    Returns:
        ``(av_offset, conf)``: the mean AV offset truncated to ``int`` and the
        mean confidence as a float.

    Raises:
        Exception: If the ``crop`` directory contains no entries.
    """
    syncnet_detector(video_path=video_path, min_track=50)

    crop_dir = os.path.join(detect_results_dir, "crop")
    clip_names = os.listdir(crop_dir)
    if not clip_names:
        raise Exception(red_text(f"Face not detected in {video_path}"))

    # Evaluate each cropped clip in listing order, keeping only the offset and
    # the confidence (the middle element of the result is unused).
    scores = [
        (offset, confidence)
        for offset, _, confidence in (
            syncnet.evaluate(video_path=os.path.join(crop_dir, clip_name), temp_dir=temp_dir)
            for clip_name in clip_names
        )
    ]

    # int() truncates the mean offset toward zero.
    mean_offset = int(fmean([offset for offset, _ in scores]))
    mean_conf = fmean([confidence for _, confidence in scores])

    print(f"Input video: {video_path}\nSyncNet confidence: {mean_conf:.2f}\nAV offset: {mean_offset}")
    return mean_offset, mean_conf
|
|
|
|
|
|
def main():
    """Command-line entry point for SyncNet evaluation.

    Parses the CLI arguments, loads the SyncNet checkpoint given by
    ``--initial_model`` onto CUDA when available (CPU otherwise), and then:

    * if ``--video_path`` is given, evaluates that single video;
    * otherwise evaluates every ``*.mp4`` file in ``--videos_dir`` in sorted
      order and prints the mean confidence over the videos that succeeded.

    In directory mode a failure on one video is printed and skipped so the
    remaining videos are still processed.
    """
    parser = argparse.ArgumentParser(description="SyncNet")
    parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
    parser.add_argument("--video_path", type=str, default=None, help="")
    parser.add_argument("--videos_dir", type=str, default="/root/processed")
    parser.add_argument("--temp_dir", type=str, default="temp", help="")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    syncnet = SyncNetEval(device=device)
    syncnet.loadParameters(args.initial_model)

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if args.video_path is not None:
        syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
        return

    sync_conf_list = []
    video_names = sorted(f for f in os.listdir(args.videos_dir) if f.endswith(".mp4"))
    for video_name in tqdm.tqdm(video_names):
        try:
            _, conf = syncnet_eval(
                syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
            )
        except Exception as e:
            # Best effort: report the failure and keep going with the rest.
            print(e)
        else:
            sync_conf_list.append(conf)

    # fmean() raises StatisticsError on empty data, which happens when the
    # directory holds no .mp4 files or every video failed above.
    if not sync_conf_list:
        print(f"No videos in {args.videos_dir} were evaluated successfully")
        return

    print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|