import glob |
import json |
import os |
import cv2 |
import pickle |
import random |
import re |
import subprocess |
from functools import partial |
import librosa.core |
import numpy as np |
import torch |
import torch.distributions |
import torch.distributed as dist |
import torch.optim |
import torch.utils.data |
from utils.commons.indexed_datasets import IndexedDataset |
from torch.utils.data import Dataset, DataLoader |
import torch.nn.functional as F |
import pandas as pd |
from tqdm import tqdm |
import csv |
from utils.commons.hparams import hparams, set_hparams |
from utils.commons.meters import Timer |
from data_util.face3d_helper import Face3DHelper |
from utils.audio import librosa_wav2mfcc |
from utils.commons.dataset_utils import collate_xd |
from utils.commons.tensor_utils import convert_to_tensor |
from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image |
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic |
from utils.commons.image_utils import load_image_as_uint8_tensor |
from modules.eg3ds.camera_utils.pose_sampler import UnifiedCameraPoseSampler |
def sample_idx(img_dir, num_frames):
    """Randomly pick a frame index whose images exist in all required folders.

    A candidate index is drawn uniformly from ``[0, num_frames - 1]`` and is
    accepted only when ``find_img_name`` locates a file for it in ``img_dir``
    and in the sibling ``head_imgs``, ``inpaint_torso_imgs`` and ``com_imgs``
    directories (derived by replacing ``"/gt_imgs/"`` in ``img_dir``).

    Args:
        img_dir: directory of the ground-truth frames (path containing
            ``/gt_imgs/``).
        num_frames: number of frames to sample from.

    Returns:
        The first sampled index for which all four images exist.

    Note:
        The function retries forever; if no index is valid it never returns.
        A warning is printed after 1000 failed attempts and then once every
        further 1000 attempts, so a broken clip is visible in the log without
        flooding it.
    """
    # Probe order is kept: gt first, then head, inpaint_torso, com.
    required_dirs = [img_dir] + [
        img_dir.replace("/gt_imgs/", f"/{name}/")
        for name in ("head_imgs", "inpaint_torso_imgs", "com_imgs")
    ]
    cnt = 0
    while True:
        cnt += 1
        # Throttled warning: fires at attempt 1001, 2001, 3001, ...
        if cnt > 1000 and cnt % 1000 == 1:
            print(f"recycle for more than 1000 times, check this {img_dir}")
        idx = random.randint(0, num_frames - 1)
        # all() short-circuits, so later folders are not probed after a miss.
        if all(find_img_name(d, idx) != 'None' for d in required_dirs):
            return idx
def find_img_name(img_dir, idx):
    """Resolve the on-disk file name of frame ``idx`` inside ``img_dir``.

    Several naming schemes are probed in a fixed priority order:
    ``00012.jpg``, ``12.jpg``, ``00000012.jpg``, ``00000012.png``,
    ``00012.png``, ``12.png``.

    Args:
        img_dir: directory that holds the frame images.
        idx: integer frame index.

    Returns:
        The path of the first candidate that exists, or the string ``'None'``
        (not the ``None`` object) when no candidate exists.
    """
    pad5 = format(idx, "05d")
    plain = str(idx)
    pad8 = format(idx, "08d")
    # Order matters: it is the lookup priority.
    candidate_names = (
        pad5 + ".jpg",
        plain + ".jpg",
        pad8 + ".jpg",
        pad8 + ".png",
        pad5 + ".png",
        plain + ".png",
    )
    for name in candidate_names:
        candidate = os.path.join(img_dir, name)
        if os.path.exists(candidate):
            return candidate
    return 'None'
def get_win_from_arr(arr, index, win_size):
    """Cut a zero-padded window of ``win_size`` rows around ``index``.

    The window covers rows ``[index - win_size // 2,
    index + (win_size - win_size // 2))`` along the first axis. Rows that fall
    outside ``arr`` are filled with zeros, so the result always has exactly
    ``win_size`` rows (for ``win_size >= 0``).

    Args:
        arr: ``np.ndarray`` or ``torch.Tensor``, windowed along axis 0.
        index: centre position of the window.
        win_size: number of rows in the returned window.

    Returns:
        An array/tensor of the same kind and dtype as ``arr`` with
        ``win_size`` rows. When no padding is needed this is a plain slice of
        ``arr``.
    """
    length = arr.shape[0]
    left = index - win_size // 2
    right = index + (win_size - win_size // 2)

    # Rows missing on each side, capped at win_size for windows that lie
    # entirely outside the array.
    pad_left = min(max(-left, 0), win_size)
    pad_right = min(max(right - length, 0), win_size)

    # Clamp the slice to the valid range so negative bounds are never
    # interpreted as "count from the end".
    start = min(max(left, 0), length)
    stop = min(max(right, 0), length)
    win = arr[start:stop]

    if pad_left <= 0 and pad_right <= 0:
        return win

    # Padding is allocated with the exact shape needed. Deriving it from a
    # slice of ``win`` yields too few rows when the pad exceeds len(win).
    tail_shape = tuple(arr.shape[1:])
    if isinstance(arr, np.ndarray):
        parts = []
        if pad_left > 0:
            parts.append(np.zeros((pad_left,) + tail_shape, dtype=arr.dtype))
        parts.append(win)
        if pad_right > 0:
            parts.append(np.zeros((pad_right,) + tail_shape, dtype=arr.dtype))
        return np.concatenate(parts, axis=0)

    parts = []
    if pad_left > 0:
        # new_zeros keeps the dtype and device of ``arr``.
        parts.append(arr.new_zeros((pad_left,) + tail_shape))
    parts.append(win)
    if pad_right > 0:
        parts.append(arr.new_zeros((pad_right,) + tail_shape))
    return torch.cat(parts, dim=0)
class Img2Plane_Dataset(Dataset):
    """Dataset yielding three camera vectors per binarized clip.

    Each sample contains:
      * ``ws_camera``   - a randomly sampled pose (narrow pitch/yaw range),
      * ``real_camera`` - a wide-range random pose or the camera of a random
        frame of the clip,
      * ``fake_camera`` - drawn independently in the same way as
        ``real_camera``.

    Frame cameras are the flattened ``c2w`` (16 values) concatenated with the
    flattened intrinsics (9 values). Randomly sampled cameras are whatever
    ``UnifiedCameraPoseSampler.get_camera_pose(...)[0]`` returns
    (presumably the same layout - TODO confirm).
    """

    def __init__(self, prefix='train', data_dir=None):
        """Store configuration; the indexed dataset itself is opened lazily.

        Args:
            prefix: name of the split file inside the binary data directory.
            data_dir: binary data directory; defaults to
                ``hparams['binary_data_dir']``.
        """
        self.db_key = prefix
        self.ds = None
        self.sizes = None
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.pose_sampler = UnifiedCameraPoseSampler()
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        """Return the number of items in the indexed dataset."""
        # NOTE(review): this (re)opens the dataset in whichever process calls
        # len() and stores the handle on self, so the lazy open in _get_item is
        # skipped afterwards - confirm that sharing this handle with
        # DataLoader workers is safe.
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def _sample_random_camera(self, max_pitch_deg, max_yaw_deg, min_dist, max_dist):
        """Sample a camera with pitch/yaw uniform in +-max degrees and a uniform distance."""
        max_pitch = max_pitch_deg / 180 * 3.1415926
        min_pitch = -max_pitch
        pitch = random.random() * (max_pitch - min_pitch) + min_pitch
        max_yaw = max_yaw_deg / 180 * 3.1415926
        min_yaw = -max_yaw
        yaw = random.random() * (max_yaw - min_yaw) + min_yaw
        distance = random.random() * (max_dist - min_dist) + min_dist
        return self.pose_sampler.get_camera_pose(
            pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]

    def _sample_view_camera(self, raw_item, img_dir, num_frames):
        """Return a wide-range random camera or the camera of a random existing frame."""
        # The coin flip is only drawn when the option is enabled (short-circuit),
        # which keeps the sequence of random draws unchanged.
        if self.hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            return self._sample_random_camera(26, 38, 2.7, 4.0)
        frame_idx = sample_idx(img_dir, num_frames)
        c2w = raw_item['c2w'][frame_idx]
        intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([c2w.reshape([16, ]), intrinsics.reshape([9, ])], axis=0)
        return convert_to_tensor(camera)

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            A dict with ``idx``, ``item_name``, ``ws_camera``, ``real_camera``
            and ``fake_camera``; or ``None`` when the raw item could not be
            loaded (such samples are dropped by ``collater``).
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        # Per-frame camera parameters derived from the stored euler/trans.
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        # Draw order (ws, real, fake) is significant for reproducibility.
        ws_camera = self._sample_random_camera(10, 16, 2.7, 3.2)
        real_camera = self._sample_view_camera(raw_item, img_dir, num_frames)
        fake_camera = self._sample_view_camera(raw_item, img_dir, num_frames)

        item.update({
            'ws_camera': ws_camera,
            'real_camera': real_camera,
            'fake_camera': fake_camera,
        })
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Return a ``DataLoader`` over this dataset that uses ``collater``."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-sample cameras into batch tensors.

        Args:
            samples: list of dicts produced by ``__getitem__``; ``None``
                entries (failed loads) are ignored.

        Returns:
            ``{}`` when no usable sample is left, otherwise a dict with
            ``ffhq_ws_cameras``, ``ffhq_ref_cameras`` and ``ffhq_mv_cameras``,
            each stacked along a new first axis.
        """
        # __getitem__ returns None on a failed load; skip those instead of
        # crashing on the subscript below.
        samples = [s for s in samples if s is not None]
        if not samples:
            return {}
        batch = {}
        batch['ffhq_ws_cameras'] = torch.stack([s['ws_camera'] for s in samples], dim=0)
        batch['ffhq_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['ffhq_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        return batch
class Motion2Video_Dataset(Dataset):
    """Dataset yielding a reference ("real") and a second ("fake") frame of a clip.

    For both frames the sample holds the camera, the gt/head/com/inpaint_torso
    images, the segmentation map with head and torso masks, the 3DMM
    coefficients (identity, expression, euler, trans) and two perturbed
    expressions. The clip's background image is stored once as ``bg_img``.
    Images are converted with ``x.float() / 127.5 - 1``.
    """

    _IMG_KEYS = ('head', 'com', 'inpaint_torso')

    def __init__(self, prefix='train', data_dir=None):
        """Store configuration; the indexed dataset itself is opened lazily.

        Args:
            prefix: name of the split file inside the binary data directory.
            data_dir: binary data directory; defaults to
                ``hparams['binary_data_dir']``.
        """
        self.db_key = prefix
        self.ds = None
        self.sizes = None
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        """Return the number of items in the indexed dataset."""
        # NOTE(review): this (re)opens the dataset in whichever process calls
        # len() and stores the handle on self, so the lazy open in _get_item is
        # skipped afterwards - confirm that sharing this handle with
        # DataLoader workers is safe.
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    @staticmethod
    def _load_rgb(fname):
        """Load an image, keep its first three channels and rescale with /127.5 - 1."""
        img = load_image_as_uint8_tensor(fname)[..., :3]
        return img.float() / 127.5 - 1

    @staticmethod
    def _sample_fake_idx(img_dir, num_frames, real_idx):
        """Sample a frame index that is far enough from ``real_idx``.

        The required distance is ``min(50, max((num_frames-1-i)//2, i//2))``
        for the candidate ``i``; candidates are redrawn until it is met.
        """
        while True:
            fake_idx = sample_idx(img_dir, num_frames)
            min_offset = min(50, max((num_frames - 1 - fake_idx) // 2, fake_idx // 2))
            if abs(fake_idx - real_idx) >= min_offset:
                return fake_idx

    def _load_frame(self, prefix, raw_item, img_dir, frame_idx, num_frames):
        """Collect every per-frame entry for ``frame_idx`` under ``{prefix}_*`` keys."""
        out = {}

        # Camera: flattened c2w (16) followed by flattened intrinsics (9).
        c2w = raw_item['c2w'][frame_idx]
        intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([c2w.reshape([16, ]), intrinsics.reshape([9, ])], axis=0)
        out[f'{prefix}_camera'] = convert_to_tensor(camera)

        # Ground-truth frame and the derived head / com / inpaint_torso images.
        gt_img_fname = find_img_name(img_dir, frame_idx)
        out[f'{prefix}_gt_img'] = self._load_rgb(gt_img_fname)
        for key in self._IMG_KEYS:
            key_img_dir = img_dir.replace("/gt_imgs/", f"/{key}_imgs/")
            key_img_fname = find_img_name(key_img_dir, frame_idx)
            out[f'{prefix}_{key}_img'] = self._load_rgb(key_img_fname)

        # Segmentation map lives next to gt_imgs as a .png with the same stem.
        # NOTE(review): cv2.imread returns None for a missing/unreadable file,
        # which surfaces here as a TypeError on the subscript.
        seg_img_name = gt_img_fname.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)[:, :, ::-1]  # reverse channel order
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img))
        out[f'{prefix}_segmap'] = segmap
        out[f'{prefix}_head_mask'] = segmap[[1, 3, 5]].sum(dim=0)
        out[f'{prefix}_torso_mask'] = segmap[[2, 4]].sum(dim=0)

        expression = convert_to_tensor(raw_item['exp'][frame_idx]).reshape([64, ])
        out.update({
            f'{prefix}_identity': convert_to_tensor(raw_item['id']).reshape([80, ]),
            f'{prefix}_expression': expression,
            f'{prefix}_euler': convert_to_tensor(raw_item['euler'][frame_idx]).reshape([3, ]),
            f'{prefix}_trans': convert_to_tensor(raw_item['trans'][frame_idx]).reshape([3, ]),
        })

        # Perturbed expressions: a neighbouring frame's expression and its
        # mirror image around the current expression (2 * e - e_neighbour).
        neighbours = [i for i in (frame_idx - 1, frame_idx + 1) if 0 <= i <= num_frames - 1]
        # A single-frame clip has no neighbour; fall back to the frame itself
        # rather than failing on random.choice([]).
        pertube_idx = random.choice(neighbours) if neighbours else frame_idx
        pertube_1 = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64, ])
        out[f'{prefix}_pertube_expression_1'] = pertube_1
        out[f'{prefix}_pertube_expression_2'] = expression * 2 - pertube_1
        return out

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            A dict with ``idx``, ``item_name``, ``bg_img`` and the ``real_*`` /
            ``fake_*`` entries described on ``_load_frame``; or ``None`` when
            the raw item could not be loaded (such samples are dropped by
            ``collater``).
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }

        # Per-frame camera parameters derived from the stored euler/trans.
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        # Reference frame first, then the second frame: this keeps the order
        # of random draws (index, neighbour, index, neighbour).
        real_idx = sample_idx(img_dir, num_frames)
        item.update(self._load_frame('real', raw_item, img_dir, real_idx, num_frames))

        bg_img_name = img_dir.replace("/gt_imgs/", "/bg_img/") + '.jpg'
        item['bg_img'] = self._load_rgb(bg_img_name)

        fake_idx = self._sample_fake_idx(img_dir, num_frames, real_idx)
        item.update(self._load_frame('fake', raw_item, img_dir, fake_idx, num_frames))
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Return a ``DataLoader`` over this dataset that uses ``collater``."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-sample entries into a batch dict with ``th1kh_*`` keys.

        Image entries are stacked and permuted with ``(0, 3, 1, 2)``, i.e. the
        last axis of each image becomes axis 1 of the batch. ``real_*`` entries
        map to ``th1kh_ref_*`` and ``fake_*`` entries to ``th1kh_mv_*``.

        Args:
            samples: list of dicts produced by ``__getitem__``; ``None``
                entries (failed loads) are ignored.

        Returns:
            ``{}`` when no usable sample is left, otherwise the batch dict.
        """
        # __getitem__ returns None on a failed load; skip those instead of
        # crashing on the subscripts below.
        samples = [s for s in samples if s is not None]
        if not samples:
            return {}

        def stack(key):
            return torch.stack([s[key] for s in samples], dim=0)

        def stack_img(key):
            return stack(key).permute(0, 3, 1, 2)

        batch = {}
        batch['th1kh_item_names'] = [s['item_name'] for s in samples]

        # Reference ("real") frame.
        batch['th1kh_ref_gt_imgs'] = stack_img('real_gt_img')
        batch['th1kh_ref_head_masks'] = stack('real_head_mask')
        batch['th1kh_ref_torso_masks'] = stack('real_torso_mask')
        batch['th1kh_ref_segmaps'] = stack('real_segmap')
        for key in self._IMG_KEYS:
            batch[f'th1kh_ref_{key}_imgs'] = stack_img(f'real_{key}_img')
        batch['th1kh_bg_imgs'] = stack_img('bg_img')
        batch['th1kh_ref_cameras'] = stack('real_camera')
        batch['th1kh_ref_ids'] = stack('real_identity')
        batch['th1kh_ref_exps'] = stack('real_expression')
        batch['th1kh_ref_eulers'] = stack('real_euler')
        batch['th1kh_ref_trans'] = stack('real_trans')

        # Second ("fake" / mv) frame.
        batch['th1kh_mv_gt_imgs'] = stack_img('fake_gt_img')
        for key in self._IMG_KEYS:
            batch[f'th1kh_mv_{key}_imgs'] = stack_img(f'fake_{key}_img')
        batch['th1kh_mv_head_masks'] = stack('fake_head_mask')
        batch['th1kh_mv_torso_masks'] = stack('fake_torso_mask')
        batch['th1kh_mv_cameras'] = stack('fake_camera')
        batch['th1kh_mv_ids'] = stack('fake_identity')
        batch['th1kh_mv_exps'] = stack('fake_expression')
        batch['th1kh_mv_eulers'] = stack('fake_euler')
        batch['th1kh_mv_trans'] = stack('fake_trans')

        # Perturbed expressions for both frames.
        batch['th1kh_ref_pertube_exps_1'] = stack('real_pertube_expression_1')
        batch['th1kh_ref_pertube_exps_2'] = stack('real_pertube_expression_2')
        batch['th1kh_mv_pertube_exps_1'] = stack('fake_pertube_expression_1')
        batch['th1kh_mv_pertube_exps_2'] = stack('fake_pertube_expression_2')
        return batch
if __name__ == '__main__':
    # Smoke test: iterate once over the training split of Img2Plane_Dataset.
    os.environ["OMP_NUM_THREADS"] = "1"
    dataset = Img2Plane_Dataset("train", 'data/binary/th1kh')
    for _batch in tqdm(dataset.get_dataloader()):
        pass