|
import sys |
|
import os |
|
sys.path.append(os.getcwd()) |
|
import os |
|
from tqdm import tqdm |
|
from data_utils.utils import * |
|
import torch.utils.data as data |
|
from data_utils.mesh_dataset import SmplxDataset |
|
from transformers import Wav2Vec2Processor |
|
|
|
|
|
class MultiVidData():
    """Aggregate per-clip ``SmplxDataset`` instances for several speakers.

    Depending on ``config.dataset_load_mode`` the constructor either
    unpickles a previously cached ``{key: SmplxDataset}`` dict (``'pickle'``),
    builds datasets from a CSV of annotated intervals (``'csv'``), or walks an
    on-disk ``<data_root>/<speaker>/<video>/<split>/<seq>`` layout (``'json'``)
    and then caches the result to a pickle file.  Call :meth:`get_dataset`
    afterwards to materialise the concatenated ``torch.utils.data`` dataset.
    """

    def __init__(self,
                 data_root,
                 speakers,
                 split='train',
                 limbscaling=False,
                 normalization=False,
                 norm_method='new',
                 split_trans_zero=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=None,
                 aud_feat_win_size=None,
                 aud_feat_dim=64,
                 feat_method='mel_spec',
                 context_info=False,
                 smplx=False,
                 audio_sr=16000,
                 convert_to_6d=False,
                 expression=False,
                 config=None
                 ):
        """Discover and load every sequence for ``speakers`` under ``data_root``.

        Args:
            data_root: Root directory with one sub-directory per speaker
                (used by the ``'json'`` load mode).
            speakers: Iterable of speaker names to load.
            split: Dataset split name; the special value ``'pre'`` is mapped
                to ``'train'``.
            limbscaling: Forwarded to each ``SmplxDataset``.
            normalization: Whether :meth:`get_dataset` should normalise data.
            norm_method: Normalisation method tag (stored, not used here).
            split_trans_zero: If True, collect separate trans/zero dataset
                lists instead of one combined list.
            num_frames: Frame count per sample; also passed as ``fps`` to
                ``SmplxDataset`` (NOTE(review): frame count reused as fps —
                confirm this aliasing is intentional).
            num_pre_frames: Number of conditioning frames.
            num_generate_length: Generated sequence length; defaults to
                ``num_frames`` when None.
            aud_feat_win_size: Audio feature window size (may be None).
            aud_feat_dim: Audio feature dimensionality.
            feat_method: Audio feature extraction method (e.g. 'mel_spec').
            context_info: Forwarded to each ``SmplxDataset``.
            smplx: In the ''json'' branch, the value ``'pose'`` restricts
                sequences to those starting with ``'clip'``.
            audio_sr: Audio sample rate forwarded to each ``SmplxDataset``.
            convert_to_6d: Use 6D rotation representation (forwarded).
            expression: Include expression parameters (forwarded).
            config: Project config object; must provide
                ``dataset_load_mode``, ``Data.pklname`` and (for the 'json'
                branch) ``Data.whole_video``.
        """
        self.data_root = data_root
        self.speakers = speakers
        self.split = split
        # 'pre' is an alias for the training split.
        if split == 'pre':
            self.split = 'train'
        self.norm_method=norm_method
        self.normalization = normalization
        self.limbscaling = limbscaling
        self.convert_to_6d = convert_to_6d
        self.num_frames=num_frames
        self.num_pre_frames=num_pre_frames
        # Default the generation length to one window of frames.
        if num_generate_length is None:
            self.num_generate_length = num_frames
        else:
            self.num_generate_length = num_generate_length
        self.split_trans_zero=split_trans_zero

        dataset = SmplxDataset

        # Two bookkeeping modes: separate trans/zero lists, or one flat list.
        if self.split_trans_zero:
            self.trans_dataset_list = []
            self.zero_dataset_list = []
        else:
            self.all_dataset_list = []
        self.dataset={}          # key -> SmplxDataset instance
        self.complete_data=[]    # per-key arrays, concatenated at the end
        self.config=config
        load_mode=self.config.dataset_load_mode

        if load_mode=='pickle':
            import pickle
            import subprocess  # NOTE(review): imported but unused here

            # Restore the previously cached {key: SmplxDataset} dict.
            # NOTE(review): no context manager and 'rb+' where 'rb' would do.
            f = open(self.split+config.Data.pklname, 'rb+')
            self.dataset=pickle.load(f)
            f.close()
            for key in self.dataset:
                self.complete_data.append(self.dataset[key].complete_data)

        elif load_mode=='csv':
            # csv_parse lives under config.config_root_path; import lazily.
            try:
                sys.path.append(self.config.config_root_path)
                from config import config_path
                from csv_parser import csv_parse
            except ImportError as e:
                print(f'err: {e}')
                raise ImportError('config root path error...')

            for speaker_name in self.speakers:
                # NOTE(review): df_intervals is None here, so the subscript on
                # the next line raises TypeError — the step that actually loads
                # the intervals CSV appears to have been removed.  Restore it
                # before using load_mode='csv'.
                df_intervals=None
                df_intervals=df_intervals[df_intervals['speaker']==speaker_name]
                df_intervals = df_intervals[df_intervals['dataset'] == self.split]

                print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
                for iter_index, (_, interval) in tqdm(
                    (enumerate(df_intervals.iterrows())),desc=f'load {speaker_name}'
                ):

                    # csv_parse expands one interval row into every derived
                    # path/metadata field; only a few are used below.
                    (
                        interval_index,
                        interval_speaker,
                        interval_video_fn,
                        interval_id,

                        start_time,
                        end_time,
                        duration_time,
                        start_time_10,
                        over_flow_flag,
                        short_dur_flag,

                        big_video_dir,
                        small_video_dir_name,
                        speaker_video_path,

                        voca_basename,
                        json_basename,
                        wav_basename,
                        voca_top_clip_path,
                        voca_json_clip_path,
                        voca_wav_clip_path,

                        audio_output_fn,
                        image_output_path,
                        pifpaf_output_path,
                        mp_output_path,
                        op_output_path,
                        deca_output_path,
                        pixie_output_path,
                        cam_output_path,
                        ours_output_path,
                        merge_output_path,
                        multi_output_path,
                        gt_output_path,
                        ours_images_path,
                        pkl_fil_path,
                    )=csv_parse(interval)

                    # Skip intervals whose motion pickle or audio is missing.
                    if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
                        continue

                    key=f'{interval_video_fn}/{small_video_dir_name}'
                    self.dataset[key] = dataset(
                        data_root=pkl_fil_path,
                        speaker=speaker_name,
                        audio_fn=audio_output_fn,
                        audio_sr=audio_sr,
                        fps=num_frames,
                        feat_method=feat_method,
                        audio_feat_dim=aud_feat_dim,
                        train=(self.split == 'train'),
                        load_all=True,
                        split_trans_zero=self.split_trans_zero,
                        limbscaling=self.limbscaling,
                        num_frames=self.num_frames,
                        num_pre_frames=self.num_pre_frames,
                        num_generate_length=self.num_generate_length,
                        audio_feat_win_size=aud_feat_win_size,
                        context_info=context_info,
                        convert_to_6d=convert_to_6d,
                        expression=expression,
                        config=self.config
                    )
                    self.complete_data.append(self.dataset[key].complete_data)

        elif load_mode=='json':
            # Shared audio-model front-end (phoneme Wav2Vec2 processor),
            # passed to every SmplxDataset as `am`.
            am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
            am_sr = 16000

            for speaker_name in self.speakers:
                speaker_root = os.path.join(self.data_root, speaker_name)

                videos=[v for v in os.listdir(speaker_root) ]
                print(videos)

                # haode/huaide ("good"/"broken"): counters for sequences that
                # loaded vs. were skipped for missing files.
                haode = huaide = 0

                for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
                    source_vid=vid

                    vid_pth = os.path.join(speaker_root, source_vid, self.split)
                    if smplx == 'pose':
                        # Pose-only mode: restrict to 'clip*' sequence dirs.
                        seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
                    else:
                        # NOTE(review): bare except silently skips any video
                        # whose split directory is missing/unreadable.
                        try:
                            seqs = [s for s in os.listdir(vid_pth)]
                        except:
                            continue

                    for s in seqs:
                        seq_root=os.path.join(vid_pth, s)
                        key = seq_root
                        # Each sequence dir must hold <s>.wav and <s>.pkl.
                        audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
                        motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
                        if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
                            huaide = huaide + 1
                            continue

                        self.dataset[key]=dataset(
                            data_root=seq_root,
                            speaker=speaker_name,
                            motion_fn=motion_fname,
                            audio_fn=audio_fname,
                            audio_sr=audio_sr,
                            fps=num_frames,
                            feat_method=feat_method,
                            audio_feat_dim=aud_feat_dim,
                            train=(self.split=='train'),
                            load_all=True,
                            split_trans_zero=self.split_trans_zero,
                            limbscaling=self.limbscaling,
                            num_frames=self.num_frames,
                            num_pre_frames=self.num_pre_frames,
                            num_generate_length=self.num_generate_length,
                            audio_feat_win_size=aud_feat_win_size,
                            context_info=context_info,
                            convert_to_6d=convert_to_6d,
                            expression=expression,
                            config=self.config,
                            am=am,
                            am_sr=am_sr,
                            whole_video=config.Data.whole_video
                        )
                        self.complete_data.append(self.dataset[key].complete_data)
                        haode = haode + 1
                print("huaide:{}, haode:{}".format(huaide, haode))
            import pickle

            # Cache the freshly built dataset dict so later runs can use
            # load_mode='pickle'.  NOTE(review): no context manager here.
            f = open(self.split+config.Data.pklname, 'wb')
            pickle.dump(self.dataset, f)
            f.close()

        # np presumably comes from the wildcard import of data_utils.utils —
        # verify.  Raises ValueError if no sequence was loaded (empty list).
        self.complete_data=np.concatenate(self.complete_data, axis=0)

        self.normalize_stats = {}

        # Normalisation statistics; left None here — presumably populated by
        # a subclass or caller before get_dataset() is used with
        # normalization=True.  TODO confirm.
        self.data_mean = None
        self.data_std = None

    def get_dataset(self):
        """Build the concatenated dataset(s) from the loaded sequences.

        Skips sequences shorter than ``num_generate_length``, pushes the
        current mean/std into each ``SmplxDataset``, and exposes the result
        as ``self.all_dataset`` (or ``self.trans_dataset`` /
        ``self.zero_dataset`` when ``split_trans_zero`` is set —
        NOTE(review): in that mode nothing visible here appends to
        ``trans_dataset_list``/``zero_dataset_list``, and
        ``all_dataset_list`` does not exist, so the append below would
        raise AttributeError; confirm against SmplxDataset.get_dataset).
        """
        self.normalize_stats['mean'] = self.data_mean
        self.normalize_stats['std'] = self.data_std

        for key in list(self.dataset.keys()):
            # Too short to generate from: leave it out of the concat.
            if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
                continue
            self.dataset[key].num_generate_length = self.num_generate_length
            self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
            self.all_dataset_list.append(self.dataset[key].all_dataset)

        if self.split_trans_zero:
            self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
            self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
        else:
            self.all_dataset = data.ConcatDataset(self.all_dataset_list)
|
|
|
|
|
|
|
|