|
import glob |
|
import json |
|
import os |
|
import cv2 |
|
import pickle |
|
import random |
|
import re |
|
import subprocess |
|
from functools import partial |
|
|
|
import librosa.core |
|
import numpy as np |
|
import torch |
|
import torch.distributions |
|
import torch.distributed as dist |
|
import torch.optim |
|
import torch.utils.data |
|
|
|
from utils.commons.indexed_datasets import IndexedDataset |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
import torch.nn.functional as F |
|
import pandas as pd |
|
from tqdm import tqdm |
|
import csv |
|
from utils.commons.hparams import hparams, set_hparams |
|
from utils.commons.meters import Timer |
|
from data_util.face3d_helper import Face3DHelper |
|
from utils.audio import librosa_wav2mfcc |
|
from utils.commons.dataset_utils import collate_xd |
|
from utils.commons.tensor_utils import convert_to_tensor |
|
from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image |
|
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic |
|
from utils.commons.image_utils import load_image_as_uint8_tensor |
|
from modules.eg3ds.camera_utils.pose_sampler import UnifiedCameraPoseSampler |
|
|
|
|
|
def sample_idx(img_dir, num_frames, max_random_trials=1000):
    """Randomly pick a frame index whose images exist in every required folder.

    A frame index is accepted only if ``find_img_name`` locates it in
    ``img_dir`` and in the sibling folders obtained by replacing the
    ``/gt_imgs/`` path component with ``/head_imgs/``, ``/inpaint_torso_imgs/``
    and ``/com_imgs/``.

    Args:
        img_dir: path of the ground-truth image folder (expected to contain
            ``/gt_imgs/``).
        num_frames: number of frames; candidates are ``0 .. num_frames - 1``.
        max_random_trials: uniform random draws attempted before falling back
            to an exhaustive scan (in shuffled order).

    Returns:
        A frame index valid in all four folders.

    Raises:
        ValueError: if ``num_frames`` is not positive.
        FileNotFoundError: if no frame index is valid in all four folders.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames} for {img_dir}")

    frame_dirs = [img_dir] + [
        img_dir.replace("/gt_imgs/", f"/{sub_dir}/")
        for sub_dir in ("head_imgs", "inpaint_torso_imgs", "com_imgs")
    ]

    def _is_valid(idx):
        # all() short-circuits, so folders are probed in order and stop at the first miss.
        return all(find_img_name(frame_dir, idx) != 'None' for frame_dir in frame_dirs)

    for _ in range(max_random_trials):
        idx = random.randint(0, num_frames - 1)
        if _is_valid(idx):
            return idx

    # Random sampling keeps missing: warn once, then scan every index so the
    # call is guaranteed to terminate instead of retrying forever.
    print(f"recycle for more than {max_random_trials} times, check this {img_dir}")
    candidates = list(range(num_frames))
    random.shuffle(candidates)
    for idx in candidates:
        if _is_valid(idx):
            return idx
    raise FileNotFoundError(
        f"no frame index in [0, {num_frames}) has images in all of {frame_dirs}")
|
|
|
|
|
def find_img_name(img_dir, idx):
    """Locate the image file for frame ``idx`` inside ``img_dir``.

    Several naming schemes are tried in a fixed priority order: 5-digit,
    unpadded and 8-digit ``.jpg``, then 8-digit, 5-digit and unpadded ``.png``.

    Returns:
        The first existing path, or the *string* ``'None'`` when no candidate
        exists (callers compare against that string).
    """
    candidate_names = (
        format(idx, "05d") + ".jpg",
        str(idx) + ".jpg",
        format(idx, "08d") + ".jpg",
        format(idx, "08d") + ".png",
        format(idx, "05d") + ".png",
        str(idx) + ".png",
    )
    for name in candidate_names:
        candidate_path = os.path.join(img_dir, name)
        if os.path.exists(candidate_path):
            return candidate_path
    return 'None'
|
|
|
|
|
def get_win_from_arr(arr, index, win_size):
    """Return a ``win_size``-row window of ``arr`` (along axis 0) around ``index``.

    The window covers positions
    ``[index - win_size // 2, index + win_size - win_size // 2)``. Positions
    outside ``[0, arr.shape[0])`` are filled with zeros. The result therefore
    always has exactly ``win_size`` rows, even when the window is longer than
    the data or lies entirely outside it. It keeps ``arr``'s trailing shape,
    dtype and (for tensors) device.

    Args:
        arr: ``np.ndarray`` or ``torch.Tensor``, windowed along its first axis.
        index: centre index; may lie outside the array.
        win_size: window length (non-negative).

    Returns:
        The padded window, of the same container type as ``arr``. When no
        padding is needed this is a slice (view) of ``arr``.
    """
    num_rows = arr.shape[0]
    left = index - win_size // 2
    right = index + (win_size - win_size // 2)

    # Window positions before row 0 / at or past the last row, capped at
    # win_size so a window lying completely outside the data stays full length.
    pad_left = min(max(-left, 0), win_size)
    pad_right = min(max(right - num_rows, 0), win_size)

    # Clamp both bounds into [0, num_rows]; a negative stop would otherwise
    # wrap around and select rows from the end of the array.
    start = min(max(left, 0), num_rows)
    stop = max(min(right, num_rows), start)
    win = arr[start:stop]

    if pad_left == 0 and pad_right == 0:
        return win

    is_numpy = isinstance(arr, np.ndarray)

    def _zeros(num):
        # Build padding explicitly: zeros_like(win[:num]) would be capped at len(win).
        shape = (num,) + tuple(arr.shape[1:])
        if is_numpy:
            return np.zeros(shape, dtype=arr.dtype)
        return torch.zeros(shape, dtype=arr.dtype, device=arr.device)

    parts = []
    if pad_left > 0:
        parts.append(_zeros(pad_left))
    parts.append(win)
    if pad_right > 0:
        parts.append(_zeros(pad_right))

    if is_numpy:
        return np.concatenate(parts, axis=0)
    return torch.cat(parts, dim=0)
|
|
|
|
|
class Img2Plane_Dataset(Dataset):
    """Dataset yielding three camera vectors (``ws``/``real``/``fake``) per record.

    Records are read lazily from an ``IndexedDataset`` stored at
    ``{ds_path}/{prefix}``.
    - ``ws_camera`` is always a randomly sampled narrow-range pose.
    - ``real_camera`` and ``fake_camera`` are each either a wide-range random
      pose (50% chance, only when ``hparams['random_sample_pose'] is True``) or
      the camera of a randomly chosen valid frame of the record.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix
        self.ds = None  # opened lazily by _open_ds()
        self.sizes = None
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.pose_sampler = UnifiedCameraPoseSampler()
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def _open_ds(self):
        """Open the ``IndexedDataset`` on first use and return the cached handle."""
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds

    def __len__(self):
        # Reuse the cached handle instead of re-opening the dataset on every len() call.
        return len(self._open_ds())

    def _get_item(self, index):
        """Return raw record ``index``.

        The dataset is opened lazily, which is meant to let DataLoader worker
        processes open their own file handles. Note that ``__len__`` may
        already have opened it in the parent process.
        """
        return self._open_ds()[index]

    def _sample_random_camera(self, max_pitch_deg, max_yaw_deg, max_distance, min_distance=2.7):
        """Sample a random camera pose looking at ``(0, 0, 0.2)``.

        Pitch and yaw are uniform in ``[-max, max]`` (limits given in degrees).
        Distance is uniform in ``[min_distance, max_distance]``. Exactly three
        ``random.random()`` draws are consumed, in the order pitch, yaw,
        distance. Returns element ``[0]`` of ``pose_sampler.get_camera_pose``.
        """
        # NOTE: the approximate pi is kept deliberately; switching to math.pi
        # would slightly shift the sampled poses.
        max_pitch = max_pitch_deg / 180 * 3.1415926
        min_pitch = -max_pitch
        pitch = random.random() * (max_pitch - min_pitch) + min_pitch
        max_yaw = max_yaw_deg / 180 * 3.1415926
        min_yaw = -max_yaw
        yaw = random.random() * (max_yaw - min_yaw) + min_yaw
        distance = random.random() * (max_distance - min_distance) + min_distance
        return self.pose_sampler.get_camera_pose(
            pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]

    @staticmethod
    def _camera_from_frame(raw_item, frame_idx):
        """Concatenate the 16 ``c2w`` values and 9 ``intrinsics`` values of one frame into a tensor."""
        frame_c2w = raw_item['c2w'][frame_idx]
        frame_intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([frame_c2w.reshape([16,]), frame_intrinsics.reshape([9,])], axis=0)
        return convert_to_tensor(camera)

    def _sample_view_camera(self, raw_item, img_dir, num_frames):
        """Return the camera for a real/fake view: a wide random pose or a frame camera."""
        # Short-circuit: random.random() is only drawn when the flag is exactly True.
        if self.hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            return self._sample_random_camera(26, 38, 4.0)
        frame_idx = sample_idx(img_dir, num_frames)
        return self._camera_from_frame(raw_item, frame_idx)

    def __getitem__(self, idx):
        """Build the camera triplet for record ``idx``.

        Returns ``None`` if the raw record cannot be loaded. Otherwise returns a
        dict with ``idx``, ``item_name``, ``ws_camera``, ``real_camera`` and
        ``fake_camera``.
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        # Sampling order (ws, then real, then fake) fixes how the RNG stream is consumed.
        ws_camera = self._sample_random_camera(10, 16, 3.2)
        real_camera = self._sample_view_camera(raw_item, img_dir, num_frames)
        fake_camera = self._sample_view_camera(raw_item, img_dir, num_frames)

        item.update({
            'ws_camera': ws_camera,
            'real_camera': real_camera,
            'fake_camera': fake_camera,
        })
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-item cameras into ``ffhq_*_cameras`` batch tensors.

        ``None`` samples (records that failed to load in ``__getitem__``) are
        dropped. An empty or all-``None`` batch yields ``{}``.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return {}
        return {
            'ffhq_ws_cameras': torch.stack([s['ws_camera'] for s in samples], dim=0),
            'ffhq_ref_cameras': torch.stack([s['real_camera'] for s in samples], dim=0),
            'ffhq_mv_cameras': torch.stack([s['fake_camera'] for s in samples], dim=0),
        }
|
|
|
|
|
|
|
class Motion2Video_Dataset(Dataset):
    """Dataset yielding a reference ("real") and a distant ("fake") frame per record.

    For each frame it loads:
    - the camera vector;
    - the gt/head/com/inpaint_torso images, mapped to [-1, 1];
    - the segmap with derived head/torso masks;
    - identity/expression/euler/trans coefficients;
    - two perturbed expressions built from a neighbouring frame.

    A per-record background image is also loaded.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix
        self.ds = None  # opened lazily by _open_ds()
        self.sizes = None
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def _open_ds(self):
        """Open the ``IndexedDataset`` on first use and return the cached handle."""
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds

    def __len__(self):
        # Reuse the cached handle instead of re-opening the dataset on every len() call.
        return len(self._open_ds())

    def _get_item(self, index):
        """Return raw record ``index``.

        The dataset is opened lazily, which is meant to let DataLoader worker
        processes open their own file handles. Note that ``__len__`` may
        already have opened it in the parent process.
        """
        return self._open_ds()[index]

    @staticmethod
    def _camera_from_frame(raw_item, frame_idx):
        """Concatenate the 16 ``c2w`` values and 9 ``intrinsics`` values of one frame into a tensor."""
        frame_c2w = raw_item['c2w'][frame_idx]
        frame_intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([frame_c2w.reshape([16,]), frame_intrinsics.reshape([9,])], axis=0)
        return convert_to_tensor(camera)

    @staticmethod
    def _load_rgb(img_fname):
        """Load an image, keep its first 3 channels and map uint8 values via ``x / 127.5 - 1``."""
        return load_image_as_uint8_tensor(img_fname)[..., :3].float() / 127.5 - 1

    @staticmethod
    def _load_segmap(gt_img_fname):
        """Load the segmap matching a gt image; return ``(segmap, head_mask, torso_mask)``."""
        seg_img_name = gt_img_fname.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)
        if seg_img is None:
            # cv2.imread signals a missing/unreadable file by returning None, not raising.
            raise FileNotFoundError(f"failed to read segmap: {seg_img_name}")
        # cv2 loads BGR; reverse the channel axis to get RGB before decoding.
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img[:, :, ::-1]))
        head_mask = segmap[[1, 3, 5]].sum(dim=0)
        torso_mask = segmap[[2, 4]].sum(dim=0)
        return segmap, head_mask, torso_mask

    @staticmethod
    def _neighbor_idx(frame_idx, num_frames):
        """Pick a random in-range neighbour (``frame_idx`` +/- 1).

        Falls back to ``frame_idx`` itself when no neighbour exists (single-frame
        clip), where ``random.choice`` on an empty list would raise IndexError.
        """
        candidates = [i for i in (frame_idx - 1, frame_idx + 1) if 0 <= i <= num_frames - 1]
        if not candidates:
            return frame_idx
        return random.choice(candidates)

    def _load_frame(self, item, raw_item, img_dir, frame_idx, num_frames, prefix):
        """Write all per-frame entries for ``frame_idx`` into ``item`` under ``f'{prefix}_...'`` keys."""
        item[f'{prefix}_camera'] = self._camera_from_frame(raw_item, frame_idx)

        gt_img_fname = find_img_name(img_dir, frame_idx)
        item[f'{prefix}_gt_img'] = self._load_rgb(gt_img_fname)
        for key in ['head', 'com', 'inpaint_torso']:
            key_img_fname = find_img_name(img_dir.replace("/gt_imgs/", f"/{key}_imgs/"), frame_idx)
            item[f'{prefix}_{key}_img'] = self._load_rgb(key_img_fname)

        segmap, head_mask, torso_mask = self._load_segmap(gt_img_fname)
        item[f'{prefix}_segmap'] = segmap
        item[f'{prefix}_head_mask'] = head_mask
        item[f'{prefix}_torso_mask'] = torso_mask

        expression = convert_to_tensor(raw_item['exp'][frame_idx]).reshape([64,])
        item.update({
            f'{prefix}_identity': convert_to_tensor(raw_item['id']).reshape([80,]),
            f'{prefix}_expression': expression,
            f'{prefix}_euler': convert_to_tensor(raw_item['euler'][frame_idx]).reshape([3,]),
            f'{prefix}_trans': convert_to_tensor(raw_item['trans'][frame_idx]).reshape([3,]),
        })

        # Perturbation 2 mirrors the neighbour's expression around the current one:
        # pert_2 = 2 * exp - pert_1.
        pertube_idx = self._neighbor_idx(frame_idx, num_frames)
        pertube_1 = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64,])
        item[f'{prefix}_pertube_expression_1'] = pertube_1
        item[f'{prefix}_pertube_expression_2'] = expression * 2 - pertube_1

    def _sample_fake_idx(self, img_dir, num_frames, real_idx, max_trials=100):
        """Sample a valid frame index far enough from ``real_idx``.

        A candidate is rejected while ``|fake - real| < min(50, max(after, before) // 2)``.
        Here ``after``/``before`` are the numbers of frames after/before the
        candidate. After ``max_trials`` rejections the last candidate is
        accepted, so the call terminates even when only frames near
        ``real_idx`` are valid.
        """
        def _min_offset(frame_idx):
            return min(50, max((num_frames - 1 - frame_idx) // 2, frame_idx // 2))

        fake_idx = sample_idx(img_dir, num_frames)
        trials = 0
        while abs(fake_idx - real_idx) < _min_offset(fake_idx):
            if trials >= max_trials:
                print(f"could not find a distant fake frame after {max_trials} trials, check this {img_dir}")
                break
            fake_idx = sample_idx(img_dir, num_frames)
            trials += 1
        return fake_idx

    def __getitem__(self, idx):
        """Build the real/fake frame pair for record ``idx`` (``None`` if loading fails)."""
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }

        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        real_idx = sample_idx(img_dir, num_frames)
        self._load_frame(item, raw_item, img_dir, real_idx, num_frames, 'real')

        # The background image path does not depend on the frame index.
        bg_img_name = img_dir.replace("/gt_imgs/", "/bg_img/") + '.jpg'
        item['bg_img'] = self._load_rgb(bg_img_name)

        fake_idx = self._sample_fake_idx(img_dir, num_frames, real_idx)
        self._load_frame(item, raw_item, img_dir, fake_idx, num_frames, 'fake')
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Collate items into a ``th1kh_*`` batch dict.

        ``None`` samples (records that failed to load) are dropped, and an empty
        or all-``None`` batch yields ``{}``. Images are converted from
        channel-last to channel-first.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return {}

        def _stack(key):
            return torch.stack([s[key] for s in samples], dim=0)

        def _stack_img(key):
            # (B, H, W, C) -> (B, C, H, W)
            return _stack(key).permute(0, 3, 1, 2)

        batch = {'th1kh_item_names': [s['item_name'] for s in samples]}

        batch['th1kh_ref_gt_imgs'] = _stack_img('real_gt_img')
        batch['th1kh_ref_head_masks'] = _stack('real_head_mask')
        batch['th1kh_ref_torso_masks'] = _stack('real_torso_mask')
        batch['th1kh_ref_segmaps'] = _stack('real_segmap')

        batch['th1kh_mv_gt_imgs'] = _stack_img('fake_gt_img')
        batch['th1kh_mv_head_masks'] = _stack('fake_head_mask')
        batch['th1kh_mv_torso_masks'] = _stack('fake_torso_mask')

        for key in ['head', 'com', 'inpaint_torso']:
            batch[f'th1kh_ref_{key}_imgs'] = _stack_img(f'real_{key}_img')
            batch[f'th1kh_mv_{key}_imgs'] = _stack_img(f'fake_{key}_img')
        batch['th1kh_bg_imgs'] = _stack_img('bg_img')

        for batch_name, item_key in [('cameras', 'camera'), ('ids', 'identity'),
                                     ('exps', 'expression'), ('eulers', 'euler'),
                                     ('trans', 'trans')]:
            batch[f'th1kh_ref_{batch_name}'] = _stack(f'real_{item_key}')
            batch[f'th1kh_mv_{batch_name}'] = _stack(f'fake_{item_key}')

        for i in (1, 2):
            batch[f'th1kh_ref_pertube_exps_{i}'] = _stack(f'real_pertube_expression_{i}')
            batch[f'th1kh_mv_pertube_exps_{i}'] = _stack(f'fake_pertube_expression_{i}')
        return batch
|
|
|
if __name__ == '__main__':
    # Smoke test: construct the Img2Plane training split and iterate over it once.
    # NOTE(review): numpy/torch are already imported at this point, so this env
    # var may not affect their thread pools -- confirm if thread count matters.
    os.environ["OMP_NUM_THREADS"] = "1"
    smoke_dataset = Img2Plane_Dataset("train", 'data/binary/th1kh')
    for _ in tqdm(smoke_dataset.get_dataloader()):
        pass
|
|