import os
import traceback
from math import ceil

import torch
import distinctipy
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import facer
import tyro

from pixel3dmm import env_paths

colors = distinctipy.get_colors(22, rng=0)
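# Fixed, seeded palette so colors are reproducible across runs; 22 entries are
# enough to cover every farl/celebm face-parsing class plus the out-of-face
# sentinel class assigned in main() below.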
def viz_results(img, seq_classes, n_classes, suppress_plot=False):
    """Render a per-pixel class map as a color image, optionally plotting it."""
    seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
    # farl/celebm class indices that could be suppressed from the visualization:
    bad_indices = [
        0,   # background
        1,   # neck
        # 2,   skin
        3,   # cloth
        4,   # ear_r (image-space right)
        5,   # ear_l
        # 6,   brow_r
        # 7,   brow_l
        # 8,   eye_r
        # 9,   eye_l
        # 10,  nose
        # 11,  mouth
        # 12,  lower_lip
        # 13,  upper_lip
        14,  # hair
        # 15,  glasses
        16,  # ??
        17,  # earring_r
        18,  # ?
    ]
    bad_indices = []  # override: visualize all classes, ignoring the list above
    for i in range(n_classes):
        if i not in bad_indices:
            seg_img[seq_classes[0, :, :] == i] = np.array(colors[i]) * 255
    if not suppress_plot:
        plt.imshow(seg_img.astype(np.uint8))
        plt.show()
    return Image.fromarray(seg_img.astype(np.uint8))
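# Example (hypothetical file names): visualize a saved class map for one frame.
# `img` only supplies the output resolution, so a C x H x W array works:
#   img = np.array(Image.open('cropped/00000.png'))[..., :3].transpose(2, 0, 1)
#   seg = np.array(Image.open('seg_og/00000.png'))[None]  # 1 x H x W class ids
#   viz_results(img, seq_classes=seg, n_classes=int(seg.max()) + 1)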
def get_color_seg(img, seq_classes, n_classes):
    """Like viz_results, but leaves classes in bad_indices black and never plots."""
    seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
    colors = distinctipy.get_colors(n_classes + 1, rng=0)
    bad_indices = [
        0,   # background
        1,   # neck
        # 2,   skin
        3,   # cloth
        4,   # ear_r (image-space right)
        5,   # ear_l
        # 6,   brow_r
        # 7,   brow_l
        # 8,   eye_r
        # 9,   eye_l
        # 10,  nose
        # 11,  mouth
        # 12,  lower_lip
        # 13,  upper_lip
        14,  # hair
        # 15,  glasses
        16,  # ??
        17,  # earring_r
        18,  # ?
    ]
    for i in range(n_classes):
        if i not in bad_indices:
            seg_img[seq_classes[0, :, :] == i] = np.array(colors[i]) * 255
    return Image.fromarray(seg_img.astype(np.uint8))
def crop_gt_img(img, seq_classes, n_classes):
    """Zero out pixels belonging to classes in bad_indices, keeping only face regions."""
    bad_indices = [
        0,   # background
        1,   # neck
        # 2,   skin
        3,   # cloth
        4,   # ear_r (image-space right)
        5,   # ear_l
        # 6,   brow_r
        # 7,   brow_l
        # 8,   eye_r
        # 9,   eye_l
        # 10,  nose
        # 11,  mouth
        # 12,  lower_lip
        # 13,  upper_lip
        14,  # hair
        # 15,  glasses
        16,  # ??
        17,  # earring_r
        18,  # ?
    ]
    for i in range(n_classes):
        if i in bad_indices:
            img[seq_classes[0, :, :] == i] = 0
    return img.astype(np.uint8)
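# A minimal sketch (an assumption, not part of this pipeline) of how the per-pixel
# class maps saved to `seg_og/` by main() could be consumed downstream. The class
# ids follow the farl/celebm table above; treating background, neck, cloth, and
# hair as "non-face" here is an illustrative choice.
def face_mask_from_seg(seg_path: str) -> np.ndarray:
    seg = np.array(Image.open(seg_path))       # H x W uint8 class map
    non_face = np.isin(seg, [0, 1, 3, 14])     # background, neck, cloth, hair
    return (~non_face).astype(np.uint8) * 255  # 255 wherever the face is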
device = 'cuda' if torch.cuda.is_available() else 'cpu'

face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/celebm/448', device=device)  # alternative: 'farl/lapa/448'
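# Note: facer typically fetches the pretrained detector/parser checkpoints on
# first use, so the initial run needs network access.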
def main(video_name: str):
    """Run face detection + parsing on the cropped frames of a preprocessed video."""
    out = f'{env_paths.PREPROCESSED_DATA}/{video_name}'
    out_seg = f'{out}/seg_og/'
    out_seg_annot = f'{out}/seg_non_crop_annotations/'
    os.makedirs(out_seg, exist_ok=True)
    os.makedirs(out_seg_annot, exist_ok=True)

    folder = f'{out}/cropped/'
    frames = [f for f in os.listdir(folder) if f.endswith('.png') or f.endswith('.jpg')]
    frames.sort()

    # Resume logic: skip the video entirely if every frame already has a segmentation.
    if len(os.listdir(out_seg)) == len(frames):
        print(f'''
        <<<<<<<< ALREADY COMPLETED SEGMENTATION FOR {video_name}, SKIPPING >>>>>>>>
        ''')
        return
    batch_size = 1
    for i in range(len(frames) // batch_size):
        image_stack = []
        frame_stack = []
        original_shapes = []
        for j in range(batch_size):
            file = frames[i * batch_size + j]
            if os.path.exists(f'{out_seg_annot}/color_{file}.png'):
                print(f'already segmented {file}, skipping')
                continue
            img = Image.open(f'{folder}/{file}')
            og_size = img.size
            original_shapes.append(og_size)
            # image: 1 x 3 x h x w
            image = facer.hwc2bchw(torch.from_numpy(np.array(img)[..., :3])).to(device=device)
            image_stack.append(image)
            frame_stack.append(file[:-4])

        for batch_idx in range(ceil(len(image_stack) / batch_size)):
            image_batch = torch.cat(image_stack[batch_idx * batch_size:(batch_idx + 1) * batch_size], dim=0)
            frame_idx_batch = frame_stack[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            og_shape_batch = original_shapes[batch_idx * batch_size:(batch_idx + 1) * batch_size]  # currently unused
            try:
                with torch.inference_mode():
                    faces = face_detector(image_batch)
                    torch.cuda.empty_cache()
                    faces = face_parser(image_batch, faces, bbox_scale_factor=1.25)
                    torch.cuda.empty_cache()

                    seg_logits = faces['seg']['logits']
                    # Pixels where every class logit is exactly zero were never written by
                    # the parser (outside all face crops); give them a sentinel class id
                    # one past the real classes.
                    back_ground = torch.all(seg_logits == 0, dim=1, keepdim=True).detach().squeeze(1).cpu().numpy()
                    seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
                    seg_classes = seg_probs.argmax(dim=1).detach().cpu().numpy().astype(np.uint8)
                    seg_classes[back_ground] = seg_probs.shape[1] + 1

                    for _iidx in range(seg_probs.shape[0]):
                        frame = frame_idx_batch[_iidx]
                        iidx = faces['image_ids'][_iidx].item()
                        try:
                            I_color = viz_results(image_batch[iidx:iidx + 1], seq_classes=seg_classes[_iidx:_iidx + 1], n_classes=seg_probs.shape[1] + 1, suppress_plot=True)
                            I_color.save(f'{out_seg_annot}/color_{frame}.png')
                        except Exception:
                            pass  # the color visualization is best-effort; the raw class map below is what matters
                        I = Image.fromarray(seg_classes[_iidx])
                        I.save(f'{out_seg}/{frame}.png')
                torch.cuda.empty_cache()
            except Exception:
                traceback.print_exc()
                continue
if __name__ == '__main__':
    tyro.cli(main)
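# Example invocation (the script name is illustrative; tyro exposes the
# `video_name` parameter as a kebab-case CLI flag by default):
#   python seg_facer.py --video-name my_video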