# pixel3dmm/scripts/run_facer_segmentation.py
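"""Per-frame face segmentation with facer.

For every frame in <PREPROCESSED_DATA>/<video_name>/cropped/, this script
detects the face (RetinaFace), parses it into semantic classes (FaRL,
celebm label set), and writes the raw class map to seg_og/ and a
color-coded visualization to seg_non_crop_annotations/.
"""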
import os
import traceback
from math import ceil

import torch
import distinctipy
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import facer
import tyro

from pixel3dmm import env_paths
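
# Fixed palette shared across frames so each class keeps the same color
# (22 entries is enough to cover the celebm label set plus the extra
# background label assigned in main()).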
colors = distinctipy.get_colors(22, rng=0)

# Segmentation classes to suppress when masking images. Indices follow the
# farl/celebm parser's label order; entries marked ?? were left unidentified
# by the original author.
BAD_INDICES = [
    0,   # background
    1,   # neck
    # 2,   skin
    3,   # cloth
    4,   # ear_r (image-space right)
    5,   # ear_l
    # 6,   brow_r
    # 7,   brow_l
    # 8,   eye_r
    # 9,   eye_l
    # 10,  nose
    # 11,  mouth
    # 12,  lower_lip
    # 13,  upper_lip
    14,  # hair
    # 15,  glasses
    16,  # ??
    17,  # earring_r
    18,  # ?
]
def viz_results(img, seq_classes, n_classes, suppress_plot=False):
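    """Color-code the per-pixel classes in seq_classes; plot unless suppress_plot."""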
seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
    # NOTE: class filtering (see BAD_INDICES) is intentionally disabled here,
    # so every class gets rendered.
    bad_indices = []
for i in range(n_classes):
if i not in bad_indices:
seg_img[seq_classes[0, :, :] == i] = np.array(colors[i])*255
if not suppress_plot:
        plt.imshow(seg_img.astype(np.uint8))
plt.show()
return Image.fromarray(seg_img.astype(np.uint8))
def get_color_seg(img, seq_classes, n_classes):
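    """Return a color-coded segmentation image with BAD_INDICES classes left black."""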
seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
    colors = distinctipy.get_colors(n_classes + 1, rng=0)
    bad_indices = BAD_INDICES
for i in range(n_classes):
if i not in bad_indices:
seg_img[seq_classes[0, :, :] == i] = np.array(colors[i])*255
return Image.fromarray(seg_img.astype(np.uint8))
def crop_gt_img(img, seq_classes, n_classes):
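    """Zero out the pixels of img whose predicted class is in BAD_INDICES."""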
    bad_indices = BAD_INDICES
for i in range(n_classes):
if i in bad_indices:
img[seq_classes[0, :, :] == i] = 0
return img.astype(np.uint8)
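
# The face detector and parser are built once at module import and reused for
# all frames.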
device = 'cuda' if torch.cuda.is_available() else 'cpu'
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/celebm/448', device=device) # optional "farl/lapa/448"
def main(video_name: str):
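    """Segment every cropped frame of video_name, skipping already-finished work."""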
out = f'{env_paths.PREPROCESSED_DATA}/{video_name}'
out_seg = f'{out}/seg_og/'
out_seg_annot = f'{out}/seg_non_crop_annotations/'
os.makedirs(out_seg, exist_ok=True)
os.makedirs(out_seg_annot, exist_ok=True)
    folder = f'{out}/cropped/'
frames = [f for f in os.listdir(folder) if f.endswith('.png') or f.endswith('.jpg')]
frames.sort()
if len(os.listdir(out_seg)) == len(frames):
print(f'''
<<<<<<<< ALREADY COMPLETED SEGMENTATION FOR {video_name}, SKIPPING >>>>>>>>
''')
return
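
    # Frames are processed in small batches (currently one frame at a time):
    # load the images, run detection + parsing, then save each frame's class
    # map and color visualization.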
    batch_size = 1
    for i in range(ceil(len(frames) / batch_size)):
        image_stack = []
        frame_stack = []
        for j in range(batch_size):
            # Guard for the final, possibly partial batch.
            if i * batch_size + j >= len(frames):
                break
            file = frames[i * batch_size + j]
            if os.path.exists(f'{out_seg_annot}/color_{file[:-4]}.png'):
                print(f'already processed {file}, skipping')
                continue
            img = Image.open(f'{folder}/{file}')
            image = facer.hwc2bchw(torch.from_numpy(np.array(img)[..., :3])).to(device=device)  # image: 1 x 3 x h x w
            image_stack.append(image)
            frame_stack.append(file[:-4])
        # image_stack holds at most batch_size images, so this loop runs at most once.
        for batch_idx in range(ceil(len(image_stack) / batch_size)):
            image_batch = torch.cat(image_stack[batch_idx * batch_size:(batch_idx + 1) * batch_size], dim=0)
            frame_idx_batch = frame_stack[batch_idx * batch_size:(batch_idx + 1) * batch_size]
try:
with torch.inference_mode():
faces = face_detector(image_batch)
torch.cuda.empty_cache()
faces = face_parser(image_batch, faces, bbox_scale_factor=1.25)
torch.cuda.empty_cache()
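                    # Pixels not covered by any detected face keep all-zero
                    # logits; flag them and assign an out-of-range class label
                    # so they stay black in the visualization.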
                    seg_logits = faces['seg']['logits']
                    background = torch.all(seg_logits == 0, dim=1).cpu().numpy()
                    seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
                    seg_classes = seg_probs.argmax(dim=1).cpu().numpy().astype(np.uint8)
                    seg_classes[background] = seg_probs.shape[1] + 1
for _iidx in range(seg_probs.shape[0]):
frame = frame_idx_batch[_iidx]
iidx = faces['image_ids'][_iidx].item()
try:
I_color = viz_results(image_batch[iidx:iidx+1], seq_classes=seg_classes[_iidx:_iidx+1], n_classes=seg_probs.shape[1] + 1, suppress_plot=True)
I_color.save(f'{out_seg_annot}/color_{frame}.png')
                        except Exception:
                            pass  # visualization is best-effort; the class map below is still saved
I = Image.fromarray(seg_classes[_iidx])
I.save(f'{out_seg}/{frame}.png')
torch.cuda.empty_cache()
            except Exception:
traceback.print_exc()
continue
if __name__ == '__main__':
tyro.cli(main)