import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import glob
import cv2
import tqdm
import numpy as np
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
import warnings
warnings.filterwarnings('ignore')
import random
random.seed(42)
import pickle
import json
import gzip
from typing import Any


def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
    """Load an object from ``filename``.

    Args:
        filename: path to read.
        is_gzip: when True, the file is gzip-compressed.
        is_json: when True, decode JSON text; otherwise unpickle binary data.

    Returns:
        The deserialized object.
    """
    if is_json:
        if is_gzip:
            # FIX: gzip text access requires mode "rt"; the original passed
            # binary mode "r" together with ``encoding``, which raises ValueError.
            with gzip.open(filename, "rt", encoding="utf-8") as f:
                return json.load(f)
        with open(filename, "r", encoding="utf-8") as f:
            return json.load(f)
    if is_gzip:
        with gzip.open(filename, "rb") as f:
            return pickle.load(f)
    with open(filename, "rb") as f:
        return pickle.load(f)


def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
    """Save ``content`` to ``filename`` (mirror of :func:`load_file`).

    Args:
        filename: path to write.
        content: object to serialize.
        is_gzip: when True, gzip-compress the output.
        is_json: when True, encode as JSON text; otherwise pickle binary data.
    """
    if is_json:
        if is_gzip:
            # FIX: must open in text mode "wt" — json.dump writes str, and the
            # original binary "w" handle would raise TypeError on write.
            with gzip.open(filename, "wt", encoding="utf-8") as f:
                json.dump(content, f)
        else:
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(content, f)
    else:
        if is_gzip:
            with gzip.open(filename, "wb") as f:
                pickle.dump(content, f)
        else:
            with open(filename, "wb") as f:
                pickle.dump(content, f)


# Per-process singleton; built lazily so each multiprocessing worker
# constructs its own MediapipeLandmarker instead of inheriting one via fork.
face_landmarker = None


def extract_lms_mediapipe_job(img):
    """Return 478 MediaPipe face landmarks for an RGB image, or None.

    ``img`` may be None (propagated as None). The landmarker is created on
    first use in each process.
    """
    if img is None:
        return None
    global face_landmarker
    if face_landmarker is None:
        face_landmarker = MediapipeLandmarker()
    return face_landmarker.extract_lm478_from_img(img)


def extract_landmark_job(img_name):
    """Extract landmarks for one image and save them as ``*_lms.npy``.

    The output path mirrors the input path with ``/images_512/`` replaced by
    ``/lms_2d/``. Existing outputs are skipped. Best-effort: any failure is
    printed and swallowed so a corrupt image cannot kill a worker.
    """
    try:
        out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
        if os.path.exists(out_name):
            print("out exists, skip...")
            return
        os.makedirs(os.path.dirname(out_name), exist_ok=True)
        img = cv2.imread(img_name)
        if img is None:
            # FIX: cv2.imread returns None for unreadable files; the original
            # sliced it immediately (TypeError) and its None-check was dead code.
            return
        img = img[:, :, ::-1]  # BGR -> RGB for MediaPipe
        lm478 = extract_lms_mediapipe_job(img)
        if lm478 is not None:
            np.save(out_name, lm478)
    except Exception as e:
        # Deliberate best-effort policy: log and continue.
        print(e)


def out_exist_job(img_name):
    """Return ``img_name`` if its landmark file is missing, else None.

    Used to filter the todo list down to unprocessed images.
    """
    out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
    if os.path.exists(out_name):
        return None
    return img_name


# def get_todo_img_names(img_names):
#     todo_img_names = []
#     for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
#         if res is not None:
#             todo_img_names.append(res)
#     return todo_img_names


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--num_workers", default=64, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')
    parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
    parser.add_argument("--load_img_names", action="store_true")
    args = parser.parse_args()
    print(f"args {args}")

    img_dir = args.img_dir
    img_names_file = os.path.join(img_dir, args.img_names_file)

    if args.load_img_names:
        # Reuse a previously saved file list (avoids a slow glob on huge dirs).
        img_names = load_file(img_names_file)
        print(f"load image names from {img_names_file}")
    else:
        if args.ds_name == 'FFHQ_MV':
            img_names = (glob.glob(os.path.join(img_dir, "ref_imgs/*.png"))
                         + glob.glob(os.path.join(img_dir, "mv_imgs/*.png")))
            img_names = sorted(img_names)
        elif args.ds_name == 'FFHQ':
            img_names = sorted(glob.glob(os.path.join(img_dir, "*.png")))
        elif args.ds_name == "PanoHeadGen":
            # img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
            img_name_patterns = ["ref/*/*.png"]
            img_names = []
            for img_name_pattern in img_name_patterns:
                img_names.extend(glob.glob(os.path.join(img_dir, img_name_pattern)))
            img_names = sorted(img_names)

    # save image names
    if not args.load_img_names:
        save_file(img_names_file, img_names)
        print(f"save image names in {img_names_file}")

    print(f"total images number: {len(img_names)}")

    # Shard the (sorted, hence deterministic) list across processes.
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(img_names) // total_process
        if process_id == total_process - 1:
            # FIX: the original compared against ``total_process`` (unreachable
            # given the assert), so the remainder images were dropped by every
            # shard. The LAST shard must take the tail.
            img_names = img_names[process_id * num_samples_per_process:]
        else:
            img_names = img_names[process_id * num_samples_per_process
                                  : (process_id + 1) * num_samples_per_process]

    # if not args.reset:
    #     img_names = get_todo_img_names(img_names)

    print(f"todo_image {img_names[:10]}")
    print(f"processing images number in this process: {len(img_names)}")

    if args.num_workers == 1:
        # Serial path: periodic flushing progress prints (~every 0.3%).
        index = 0
        for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
            try:
                extract_landmark_job(img_name)
            except Exception as e:
                print(e)
            if index % max(1, int(len(img_names) * 0.003)) == 0:
                print(f"processed {index} / {len(img_names)}")
                sys.stdout.flush()
            index += 1
    else:
        for i, res in multiprocess_run_tqdm(
                extract_landmark_job, img_names,
                num_workers=args.num_workers,
                desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
            print(f"processed {i+1} / {len(img_names)}")
            sys.stdout.flush()

    print(f"Root {args.process_id}: Finished extracting.")