|
import json
import os
import pickle
import random
import sys
from glob import glob

import numpy as np
import torch
import torch.utils.data as data

# Make the repo root importable when running this file from the project directory.
sys.path.append(os.getcwd())

from data_utils.utils import *
from data_utils.consts import speaker_id
from data_utils.lower_body import count_part
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d

# PCA basis for the compressed hand poses; to3d() expands the 12-D coefficients
# stored in the data back to full per-joint axis-angle rotations.
with open('data_utils/hand_component.json') as file_obj:
    comp = json.load(file_obj)
    left_hand_c = np.asarray(comp['left'])
    right_hand_c = np.asarray(comp['right'])
|
|
|
|
|
def to3d(data):
    # Expand the 12-D PCA hand coefficients (columns 75:87 and 87:99) back to
    # full axis-angle hand poses using the stored principal components.
    left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
    right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
    data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
    return data
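
# A quick shape sketch (placeholder values, assuming 45 axis-angle channels per
# hand, which matches the 165-D / 330-D layouts used below): a (T, 99)
# compressed batch expands to (T, 75 + 45 + 45) = (T, 165).
#
#     frames = np.zeros((8, 99))
#     assert to3d(frames).shape == (8, 165)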
|
|
|
|
|
class SmplxDataset():
    '''
    Create a dataset for every segment and concatenate them.
    '''
|
|
|
    def __init__(self,
                 data_root,
                 speaker,
                 motion_fn,
                 audio_fn,
                 audio_sr,
                 fps,
                 feat_method='mel_spec',
                 audio_feat_dim=64,
                 audio_feat_win_size=None,
                 train=True,
                 load_all=False,
                 split_trans_zero=False,
                 limbscaling=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=25,
                 context_info=False,
                 convert_to_6d=False,
                 expression=False,
                 config=None,
                 am=None,
                 am_sr=None,
                 whole_video=False
                 ):
        self.data_root = data_root
        self.speaker = speaker

        self.feat_method = feat_method
        self.audio_fn = audio_fn
        self.audio_sr = audio_sr
        self.fps = fps
        self.audio_feat_dim = audio_feat_dim
        self.audio_feat_win_size = audio_feat_win_size
        self.context_info = context_info
        self.convert_to_6d = convert_to_6d
        self.expression = expression

        self.train = train
        self.load_all = load_all
        self.split_trans_zero = split_trans_zero
        self.limbscaling = limbscaling
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        self.num_generate_length = num_generate_length

        self.config = config
        self.am_sr = am_sr
        self.whole_video = whole_video
        load_mode = self.config.dataset_load_mode

        if load_mode == 'pickle':
            raise NotImplementedError

        elif load_mode == 'csv':
            # Despite the name, this mode reads a single pickled file at data_root.
            with open(data_root, 'rb') as f:
                data = pickle.load(f)
            self.data = data[0]
            if self.load_all:
                self._load_npz_all()

        elif load_mode == 'json':
            # Despite the name, this mode globs per-segment .pkl annotations under data_root.
            self.annotations = sorted(glob(data_root + '/*pkl'))
            if len(self.annotations) == 0:
                raise FileNotFoundError(f'no .pkl annotations found in {data_root}')
            self.img_name_list = self.annotations

            if self.load_all:
                self._load_them_all(am, am_sr, motion_fn)
|
|
|
    def _load_npz_all(self):
        self.loaded_data = {}
        self.complete_data = []
        data = self.data
        shape = data['body_pose_axis'].shape[0]
        self.betas = data['betas']
        self.img_name_list = []
        for index in range(shape):
            img_name = f'{index:6d}'
            self.img_name_list.append(img_name)

            jaw_pose = data['jaw_pose'][index]
            leye_pose = data['leye_pose'][index]
            reye_pose = data['reye_pose'][index]
            global_orient = data['global_orient'][index]
            body_pose = data['body_pose_axis'][index]
            left_hand_pose = data['left_hand_pose'][index]
            right_hand_pose = data['right_hand_pose'][index]

            # 3 (jaw) + 3 + 3 (eyes) + 3 (global orient) + 63 (body) + 12 + 12 (PCA hands)
            full_body = np.concatenate(
                (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
            assert full_body.shape[0] == 99
            if self.convert_to_6d:
                # to3d expects a batch axis, so lift the frame to (1, 99); the PCA
                # hands expand it to (1, 165), and the 55 axis-angle joints are then
                # converted to the 6-D rotation representation, giving (1, 330).
                full_body = to3d(full_body[np.newaxis, :])
                full_body = torch.from_numpy(full_body)
                full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
                full_body = np.asarray(full_body)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body.reshape(-1), expression))

            else:
                full_body = to3d(full_body[np.newaxis, :]).reshape(-1)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))

            self.loaded_data[img_name] = full_body.reshape(-1)
            self.complete_data.append(full_body.reshape(-1))

        self.complete_data = np.array(self.complete_data)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
        else:
            if self.feat_method == 'mel_spec':
                self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr,
                                              n_mels=self.audio_feat_dim)
            elif self.feat_method == 'mfcc':
                self.audio_feat = get_mfcc(self.audio_fn,
                                           smlpx=True,
                                           sr=self.audio_sr,
                                           n_mfcc=self.audio_feat_dim,
                                           win_size=self.audio_feat_win_size
                                           )
|
|
|
    def _load_them_all(self, am, am_sr, motion_fn):
        self.loaded_data = {}
        self.complete_data = []
        with open(motion_fn, 'rb') as f:
            data = pickle.load(f)

        self.betas = np.array(data['betas'])

        jaw_pose = np.array(data['jaw_pose'])
        leye_pose = np.array(data['leye_pose'])
        reye_pose = np.array(data['reye_pose'])
        global_orient = np.array(data['global_orient']).squeeze()
        body_pose = np.array(data['body_pose_axis'])
        left_hand_pose = np.array(data['left_hand_pose'])
        right_hand_pose = np.array(data['right_hand_pose'])

        # (T, 99) per-frame layout, matching _load_npz_all above.
        full_body = np.concatenate(
            (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
        assert full_body.shape[1] == 99

        if self.convert_to_6d:
            full_body = to3d(full_body)
            full_body = torch.from_numpy(full_body)
            full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
            full_body = np.asarray(full_body)
            if self.expression:
                expression = np.array(data['expression'])
                full_body = np.concatenate((full_body, expression), axis=1)

        else:
            full_body = to3d(full_body)
            expression = np.array(data['expression'])
            full_body = np.concatenate((full_body, expression), axis=1)

        self.complete_data = np.array(full_body)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
        else:
            self.audio_feat = get_mfcc_ta(self.audio_fn,
                                          smlpx=True,
                                          fps=30,
                                          sr=self.audio_sr,
                                          n_mfcc=self.audio_feat_dim,
                                          win_size=self.audio_feat_win_size,
                                          type=self.feat_method,
                                          am=am,
                                          am_sr=am_sr,
                                          encoder_choice=self.config.Model.encoder_choice,
                                          )
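
    # After loading, self.complete_data is a (T, D) motion array (D = 330 in 6-D
    # rotation mode, 165 otherwise, plus any appended expression channels) and
    # self.audio_feat is a per-frame audio feature sequence; the __Worker__
    # returned by get_dataset() slices the two frame-for-frame.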
|
|
|
|
|
|
|
    def get_dataset(self, normalization=False, normalize_stats=None, split='train'):

        class __Worker__(data.Dataset):
            # `child` serves as the worker's own instance; `self` still refers to
            # the enclosing SmplxDataset, whose loaded data the worker reads.
            def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
                super().__init__()
                child.index_list = index_list
                child.normalization = normalization
                child.normalize_stats = normalize_stats
                child.split = split

            def __getitem__(child, index):
                num_generate_length = self.num_generate_length
                num_pre_frames = self.num_pre_frames
                seq_len = num_generate_length + num_pre_frames

                index = child.index_list[index]
                # Jitter the clip start by 0 or 3 frames; keep the original start
                # if the jittered window would run past the end of the sequence.
                index_new = index + random.randrange(0, 5, 3)
                if index_new + seq_len > self.complete_data.shape[0]:
                    index_new = index
                index = index_new

                if child.split in ['val', 'pre', 'test'] or self.whole_video:
                    # Evaluation consumes the whole sequence from frame 0.
                    index = 0
                    seq_len = self.complete_data.shape[0]
                assert index + seq_len <= self.complete_data.shape[0]

                seq_data = self.complete_data[index:(index + seq_len), :]
                seq_data = np.array(seq_data)
|
|
|
                # audio feature
                if not self.context_info:
                    if not self.whole_video:
                        audio_feat = self.audio_feat[index:index + seq_len, ...]
                        if audio_feat.shape[0] < seq_len:
                            audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
                                                mode='reflect')

                        assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
                    else:
                        audio_feat = self.audio_feat

                else:
                    # Context mode: also return the audio for the preceding frames.
                    if self.audio_feat_win_size is None:
                        audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
                        if audio_feat.shape[0] < seq_len + num_pre_frames:
                            audio_feat = np.pad(audio_feat,
                                                [[0, seq_len + num_pre_frames - audio_feat.shape[0]], [0, 0]],
                                                mode='constant')

                        assert audio_feat.shape[0] == seq_len + num_pre_frames and audio_feat.shape[1] == self.audio_feat_dim
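                # In either branch, audio_feat is (frames, audio_feat_dim); the
                # dict construction below transposes it to (audio_feat_dim, frames).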
|
|
|
                if child.normalization:
                    data_mean = child.normalize_stats['mean'].reshape(1, -1)
                    data_std = child.normalize_stats['std'].reshape(1, -1)
                    seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
                if child.split in ['train', 'test']:
                    if self.convert_to_6d:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas,
                                'aud_file': self.audio_fn,
                            }
                        else:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'nzero': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    else:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :165].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 165:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'aud_file': self.audio_fn,
                                'betas': self.betas
                            }
                        else:
                            data_sample = {
                                'poses': seq_data.astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    return data_sample
                else:
                    data_sample = {
                        'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                        'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                        'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                        'aud_file': self.audio_fn,
                        'speaker': speaker_id[self.speaker],
                        'betas': self.betas
                    }
                    return data_sample
|
            def __len__(child):
                return len(child.index_list)

        if split == 'train':
            # Window starts with stride 6, leaving room for a full
            # (num_generate_length + num_pre_frames) clip at every start index.
            index_list = list(
                range(0,
                      min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
                      6))
        elif split in ['val', 'test']:
            index_list = [0]
        if self.whole_video:
            index_list = [0]
        self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)

    def __len__(self):
        return len(self.img_name_list)
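
# A minimal usage sketch (paths and the config object are hypothetical; the code
# above only reads config.dataset_load_mode and config.Model.encoder_choice, and
# `speaker` must be a key of data_utils.consts.speaker_id):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(dataset_load_mode='json',
#                           Model=SimpleNamespace(encoder_choice='mfcc'))
#     ds = SmplxDataset(data_root='datasets/speaker1/seq0',
#                       speaker='speaker1',
#                       motion_fn='datasets/speaker1/seq0/motion.pkl',
#                       audio_fn='datasets/speaker1/seq0/audio.wav',
#                       audio_sr=16000, fps=30, load_all=True, config=cfg)
#     ds.get_dataset(normalization=False, split='train')
#     loader = data.DataLoader(ds.all_dataset, batch_size=32, shuffle=True)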
|
|