File size: 4,752 Bytes
e34aada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
# Presumably pins OpenMP to one thread per process so that the many worker processes
# launched via multiprocess_run_tqdm below do not oversubscribe the CPU. It is set before
# numpy/cv2/torch are imported because OpenMP runtimes typically read it at initialization.
os.environ["OMP_NUM_THREADS"] = "1"

import glob
import cv2
import tqdm
import numpy as np
import PIL
from utils.commons.tensor_utils import convert_to_np
import torch
import mediapipe as mp
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
# Module-level segmenter, built once at import time and used by every extract_segment_job
# call made in this process (via _cal_seg_map and _seg_out_img_with_segmap).
# NOTE(review): whether worker processes reuse this instance or re-create it on import
# depends on how multiprocess_run_tqdm starts workers (fork vs spawn) — confirm.
seg_model = MediapipeSegmenter()


def extract_segment_job(img_name):
    """Segment one image and write every derived image for it.

    Output paths are built by replacing the ``/images_512/`` component of
    ``img_name`` with a per-product directory:

    * ``/bg_img/``  -- background estimated by ``extract_background``
    * ``/com_imgs/`` -- the input with its background region replaced by that estimate
    * ``/{head,torso,person,torso_with_bg,bg}_imgs/`` -- outputs of
      ``seg_model._seg_out_img_with_segmap`` for each mode
    * ``/inpaint_torso_imgs/`` -- first output of ``inpaint_torso_job``
    * ``/inpaint_torso_with_com_bg_imgs/`` -- second output of ``inpaint_torso_job``
      with its background region replaced by the estimated background

    Args:
        img_name: path to an input image; expected to contain ``/images_512/``.

    Returns:
        0 on success, 1 if any step failed. Failures are printed (with the image
        path and traceback) rather than raised, so one bad image does not abort a
        multiprocess run.
    """
    import traceback

    def _out_path(subdir):
        # Derive an output path by swapping the input directory component.
        # NOTE: str.replace swaps every occurrence of "/images_512/" in the path.
        return img_name.replace("/images_512/", subdir)

    try:
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread reports unreadable/missing files by returning None, not raising.
            raise IOError(f"cv2.imread could not read image: {img_name}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        segmap = seg_model._cal_seg_map(img)
        bg_img = extract_background([img], [segmap])
        # NOTE(review): the ".mp4" -> ".jpg" rename looks carried over from the video
        # pipeline; it is a no-op for image inputs such as .png.
        save_rgb_image_to_path(bg_img, _out_path("/bg_img/").replace(".mp4", ".jpg"))

        # Channel 0 of segmap is used as the background mask (presumably the background
        # class of MediapipeSegmenter — confirm), broadcast to the 3 color channels.
        bg_part = segmap[0].astype(bool)[..., None].repeat(3, axis=-1)
        com_img = img.copy()
        com_img[bg_part] = bg_img[bg_part]
        save_rgb_image_to_path(com_img, _out_path("/com_imgs/"))

        for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
            out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
            out_img_name = _out_path(f"/{mode}_imgs/")
            out_dir = os.path.dirname(out_img_name)
            if out_dir:
                # exist_ok already tolerates concurrent creation; any other error
                # (e.g. permissions) must surface instead of being swallowed.
                os.makedirs(out_dir, exist_ok=True)
            # cv2.imwrite signals most write failures by returning False, not raising.
            if not cv2.imwrite(out_img_name, cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)):
                raise IOError(f"cv2.imwrite failed for: {out_img_name}")

        inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
        save_rgb_image_to_path(inpaint_torso_img, _out_path("/inpaint_torso_imgs/"))
        # Replace the background region in place with the estimated background.
        inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
        save_rgb_image_to_path(inpaint_torso_with_bg_img, _out_path("/inpaint_torso_with_com_bg_imgs/"))
        return 0
    except Exception as e:
        # Deliberate best-effort: report which image failed and why, then keep going.
        print(f"[extract_segment_job] failed on {img_name}: {e}")
        traceback.print_exc()
        return 1

def out_exist_job(img_name):
    """Return ``img_name`` if it still needs processing, otherwise ``None``.

    An image is considered done when its outputs under ``head_imgs``,
    ``com_imgs`` and ``inpaint_torso_with_com_bg_imgs`` all exist. Each output
    path is derived by replacing ``/images_512/`` in ``img_name``.
    """
    required_subdirs = ("/head_imgs/", "/com_imgs/", "/inpaint_torso_with_com_bg_imgs/")
    already_done = all(
        os.path.exists(img_name.replace("/images_512/", subdir))
        for subdir in required_subdirs
    )
    return None if already_done else img_name

def get_todo_img_names(img_names, num_workers=64):
    """Keep only the images whose outputs are not all present yet.

    Args:
        img_names: iterable of input image paths.
        num_workers: worker count passed to ``multiprocess_run_tqdm`` for the
            existence checks (default 64, the previously hard-coded value).

    Returns:
        List of paths for which ``out_exist_job`` returned a non-None value, in the
        order ``multiprocess_run_tqdm`` yields results.
    """
    img_names = list(img_names)
    if not img_names:
        # Nothing to check; avoid starting a worker pool for an empty workload.
        return []
    todo_img_names = []
    for _, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=num_workers):
        if res is not None:
            todo_img_names.append(res)
    return todo_img_names


if __name__ == '__main__':
    # Entry point: collect input images, shard them across processes, skip already
    # finished ones (unless --reset), then run extract_segment_job on the rest.
    import argparse, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='./images_512')
    # parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
    # Restrict to the supported layouts; any other value previously crashed later with
    # a NameError because img_names was never assigned.
    parser.add_argument("--ds_name", default='FFHQ', choices=['FFHQ', 'FFHQ_MV'])
    parser.add_argument("--num_workers", default=1, type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')

    args = parser.parse_args()
    img_dir = args.img_dir
    if args.ds_name == 'FFHQ_MV':
        img_names1 = glob.glob(os.path.join(img_dir, "ref_imgs/*.png"))
        img_names2 = glob.glob(os.path.join(img_dir, "mv_imgs/*.png"))
        img_names = img_names1 + img_names2
    else:  # 'FFHQ' (guaranteed by argparse choices)
        img_names = glob.glob(os.path.join(img_dir, "*.png"))

    # Sort before a seeded shuffle so every process sees the same ordering and the
    # shards below are disjoint and together cover all images.
    img_names = sorted(img_names)
    random.seed(args.seed)
    random.shuffle(img_names)

    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        if not 0 <= process_id < total_process:
            # parser.error exits with a usage message; unlike assert it is not stripped by -O.
            parser.error(f"--process_id must be in [0, {total_process - 1}] when --total_process={total_process}")
        num_samples_per_process = len(img_names) // total_process
        start = process_id * num_samples_per_process
        if process_id == total_process - 1:
            # The last process also takes the len(img_names) % total_process leftovers,
            # which were previously dropped because this branch was unreachable.
            img_names = img_names[start:]
        else:
            img_names = img_names[start:start + num_samples_per_process]

    if not args.reset:
        img_names = get_todo_img_names(img_names)
    print(f"todo images number: {len(img_names)}")

    num_failed = 0
    for _, ret in multiprocess_run_tqdm(extract_segment_job, img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
        # extract_segment_job returns 0 on success and 1 on failure.
        if ret != 0:
            num_failed += 1
    print(f"Root process {args.process_id}: {num_failed} of {len(img_names)} images failed")