File size: 6,547 Bytes
e34aada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import glob
import cv2
import pickle
import tqdm
import numpy as np
import mediapipe as mp
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
from utils.commons.os_utils import multiprocess_glob
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
import warnings
import traceback

warnings.filterwarnings('ignore')

"""
基于Face_aligment的lm68已被弃用,因为其:
1. 对眼睛部位的预测精度极低
2. 无法在大偏转角度时准确预测被遮挡的下颚线, 导致大角度时3dmm的GT label就是有问题的, 从而影响性能
我们目前转而使用基于mediapipe的lm68
"""
# def extract_landmarks(ori_imgs_dir):

#     print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')

#     fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
#     image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
#     for image_path in tqdm.tqdm(image_paths):
#         out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
#         if os.path.exists(out_name):
#             continue
#         input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
#         input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
#         preds = fa.get_landmarks(input)
#         if preds is None:
#             print(f"Skip {image_path} for no face detected")
#             continue
#         if len(preds) > 0:
#             lands = preds[0].reshape(-1, 2)[:,:2]
#             os.makedirs(os.path.dirname(out_name), exist_ok=True)
#             np.savetxt(out_name, lands, '%f')
#     del fa
#     print(f'[INFO] ===== extracted face landmarks =====')

def save_file(name, content):
    """Pickle `content` to the file at path `name` (binary mode)."""
    with open(name, "wb") as fout:
        pickle.dump(content, fout)
        
def load_file(name):
    """Load and return the pickled object stored at path `name`."""
    with open(name, "rb") as fin:
        return pickle.load(fin)


face_landmarker = None
    
def extract_landmark_job(video_name, nerf=False):
    """Extract 478-point Mediapipe 2D landmarks for one video and save them as .npy.

    Args:
        video_name: path to the source .mp4 file.
        nerf: if True, use the single-video (NeRF) output layout
            (.../raw/x.mp4 -> .../processed/x/lms_2d.npy); otherwise use the
            dataset layout (.../video/x.mp4 -> .../lms_2d/x_lms.npy).

    Returns:
        True on success, False on failure, None when the output already exists.
    """
    try:
        if nerf:
            out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/lms_2d.npy")
        else:
            out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
        if os.path.exists(out_name):
            # Output was produced by a previous run; skip the expensive extraction.
            return
        try:
            os.makedirs(os.path.dirname(out_name), exist_ok=True)
        except OSError:
            # e.g. empty dirname or a filesystem race; np.save below will fail
            # loudly if the directory is truly unusable.
            pass
        global face_landmarker
        if face_landmarker is None:
            # Lazy init: each worker process builds its own landmarker model.
            face_landmarker = MediapipeLandmarker()
        # Per-frame image-mode and video-mode landmarks are merged into one
        # [T, 478, 2-ish] array by the extractor's own combine routine.
        img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
        lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
        np.save(out_name, lm478)
        return True
    except Exception:
        # Worker jobs must never crash the pool: log the traceback and report failure.
        traceback.print_exc()
        return False
        
def out_exist_job(vid_name):
    """Return `vid_name` if its landmark output file is missing, else None."""
    target = vid_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
    return None if os.path.exists(target) else vid_name
    
def get_todo_vid_names(vid_names):
    """Filter `vid_names` down to the videos whose landmark output is missing.

    A single-element list is treated as NeRF mode (different output layout)
    and returned unchanged.
    """
    if len(vid_names) == 1:  # nerf: single video, checked elsewhere
        return vid_names
    pending = []
    for _, missing in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
        if missing is not None:
            pending.append(missing)
    return pending

if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--vid_dir", default='nerf')
    parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
    parser.add_argument("--num_workers", default=2, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action="store_true")
    parser.add_argument("--load_names", action="store_true")

    args = parser.parse_args()
    vid_dir = args.vid_dir
    ds_name = args.ds_name
    load_names = args.load_names

    is_nerf = ds_name.lower() == 'nerf'
    if is_nerf:  # process a single video (NeRF mode)
        vid_names = [vid_dir]
        out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4", "/lms_2d.npy") for video_name in vid_names]
    else:  # process a whole dataset
        # Each dataset stores its .mp4 files at a different directory depth.
        if ds_name in ['lrs3_trainval']:
            vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
        elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
            vid_name_pattern = os.path.join(vid_dir, "*.mp4")
        elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
            vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
        elif ds_name in ["RAVDESS", 'VFHQ']:
            vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
        else:
            raise NotImplementedError()

        # Cache the (expensive) recursive glob so later runs can reuse it.
        vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
        if os.path.exists(vid_names_path) and load_names:
            print(f"loading vid names from {vid_names_path}")
            vid_names = load_file(vid_names_path)
        else:
            vid_names = multiprocess_glob(vid_name_pattern)
        vid_names = sorted(vid_names)
        if not load_names:
            print(f"saving vid names to {vid_names_path}")
            save_file(vid_names_path, vid_names)
        out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy") for video_name in vid_names]

    # Shard the work across `total_process` independent script invocations.
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(vid_names) // total_process
        # BUGFIX: the last shard (process_id == total_process - 1) takes the
        # remainder; the previous `process_id == total_process` check could
        # never be true, so trailing videos were silently dropped.
        if process_id == total_process - 1:
            vid_names = vid_names[process_id * num_samples_per_process:]
        else:
            vid_names = vid_names[process_id * num_samples_per_process: (process_id + 1) * num_samples_per_process]

    if not args.reset:
        # Skip videos whose output already exists unless --reset is given.
        vid_names = get_todo_vid_names(vid_names)
    print(f"todo videos number: {len(vid_names)}")

    fail_cnt = 0
    job_args = [(vid_name, is_nerf) for vid_name in vid_names]
    for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
        if res is False:
            fail_cnt += 1
        print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
        sys.stdout.flush()