|
import os |
|
import sys |
|
sys.path.append(os.getcwd()) |
|
|
|
from glob import glob |
|
|
|
from argparse import ArgumentParser |
|
import json |
|
|
|
from evaluation.util import * |
|
from evaluation.metrics import * |
|
from tqdm import tqdm |
|
|
|
# Command-line interface: --speaker selects the test-audio folder, and
# --post_fix lists one or more suffixes of the prediction JSON files to read.
parser = ArgumentParser()
parser.add_argument('--speaker', type=str, required=True)
parser.add_argument('--post_fix', type=str, nargs='+', default=['paper_model'])
args = parser.parse_args()

speaker = args.speaker

# Collect this speaker's test .wav files; sorted so the run order is stable.
test_audios = glob('pose_dataset/videos/test_audios/%s/*.wav' % speaker)
test_audios.sort()
|
|
|
# Metric accumulators, filled by the evaluation loop and averaged at the end.
precision_list, recall_list, accuracy_list = [], [], []
|
|
|
# Score every test clip: derive a mode-transition sequence from the ground
# truth and from each prediction, then record how consistent the two are.
for aud in tqdm(test_audios):
    # Prediction files sit next to the audio: <audio stem>_<post_fix>.json
    base_name = os.path.splitext(aud)[0]
    gt_path = get_full_path(aud, speaker, 'val')
    _, gt_poses, _ = get_gts(gt_path)

    # Skip clips whose ground truth has fewer than 50 entries along the first
    # axis (presumably frames -- confirm against get_gts).
    if gt_poses.shape[0] < 50:
        continue

    # Prepend a length-1 axis to the ground truth before conversion/scoring.
    gt_poses = gt_poses[np.newaxis, ...]

    # NOTE(review): the source's indentation was lost; scoring is assumed to
    # happen once per post_fix so that every requested prediction is counted.
    # With a single post_fix (the default) this is the same as scoring after
    # the loop -- confirm the intended nesting.
    for post_fix in args.post_fix:
        pred_path = base_name + '_' + post_fix + '.json'

        # Use a context manager so the prediction file is closed right away
        # instead of leaking one open handle per clip.
        with open(pred_path) as pred_file:
            pred_poses = np.array(json.load(pred_file))

        pred_poses = cvt25(pred_poses, gt_poses)

        # Kept after cvt25, in the original statement order, because the
        # helpers are opaque here.
        gt_valid_points = valid_points(gt_poses)
        pred_valid_points = valid_points(pred_poses)

        gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)
        pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)

        precision, recall, accuracy = mode_transition_consistency(
            pred_mode_transition_seq, gt_mode_transition_seq)
        precision_list.append(precision)
        recall_list.append(recall)
        accuracy_list.append(accuracy)
|
# Report how many samples were scored for each metric.
print(len(precision_list), len(recall_list), len(accuracy_list))

if precision_list:
    # The names are deliberately rebound from lists to their scalar means,
    # matching what any later code would have seen before.
    precision_list = np.mean(precision_list)
    recall_list = np.mean(recall_list)
    accuracy_list = np.mean(accuracy_list)
else:
    # Nothing was scored (no audio matched, or every clip was skipped).
    # np.mean([]) would emit a RuntimeWarning and return nan; say why on
    # stderr and report nan directly so stdout stays the same.
    print('warning: no sequences were evaluated; metrics are undefined',
          file=sys.stderr)
    precision_list = recall_list = accuracy_list = float('nan')

print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
|
|