import glob
import json
import os
import pickle
import random
import re
import subprocess
from functools import partial

import librosa.core
import numpy as np
import torch
import torch.distributions
import torch.distributed as dist
import torch.optim
import torch.utils.data
from utils.commons.indexed_datasets import IndexedDataset
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import csv

from utils.commons.hparams import hparams, set_hparams
from utils.commons.meters import Timer
from data_util.face3d_helper import Face3DHelper
from utils.audio import librosa_wav2mfcc
from utils.commons.dataset_utils import collate_xd


class SyncNet_Dataset(Dataset):
    """Dataset of paired HuBERT audio features and face id/expression clips.

    Items are read from an ``IndexedDataset`` stored at
    ``{data_dir}/{prefix}``. Every item is cropped to a random window of at
    most ``x_maxframes`` audio frames together with the matching
    ``x_maxframes // 2`` video frames (two audio frames per video frame).
    """

    def __init__(self, prefix='train', data_dir=None):
        """Configure the dataset; the binary data itself is opened lazily.

        Args:
            prefix: Name of the split inside the binary data directory.
            data_dir: Directory holding the binarized data. Falls back to
                ``hparams['binary_data_dir']`` when ``None``.
        """
        self.hparams = hparams
        self.db_key = prefix
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir
        self.ds = None  # opened on first use, see _open_ds()
        self.sizes = None
        # Maximum clip length in audio (HuBERT) frames. __getitem__ pairs it
        # with x_maxframes // 2 = 100 video frames.
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8

    def _open_ds(self):
        """Return the underlying IndexedDataset, opening it on first use."""
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds

    def __len__(self):
        """Return the number of items in the binary dataset.

        The dataset handle is reused across calls instead of being
        re-created (and the previous handle dropped) on every ``len()``.
        """
        return len(self._open_ds())

    def _get_item(self, index):
        """Fetch the raw item at ``index``.

        The dataset is opened here on demand rather than in ``__init__`` so
        that the files can be opened by whichever worker first reads from it.
        """
        return self._open_ds()[index]

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            A dict with keys ``idx``, ``item_id``, ``id``, ``exp`` and
            ``hubert``, or ``None`` when the raw item could not be loaded.
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_id': raw_item['img_dir'],
            'id': torch.from_numpy(raw_item['id']).float(),  # [T_y, c] or [1, c] when global
            'exp': torch.from_numpy(raw_item['exp']).float(),  # [T_y, c]
        }
        if item['id'].shape[0] == 1:
            # A single global identity vector: broadcast it over all frames.
            item['id'] = item['id'].repeat([item['exp'].shape[0], 1])
        item['hubert'] = torch.from_numpy(raw_item['hubert']).float()  # [T_x, 1024]

        x_len = len(item['hubert'])
        y_len = x_len // 2  # video is 25fps, i.e. half the audio frame rate
        item['id'] = item['id'][:y_len]
        item['exp'] = item['exp'][:y_len]

        # Randomly select a fixed-length clip. The start is rounded down to an
        # even audio frame so that it maps onto a whole video frame.
        start_frames = random.randint(0, max(0, x_len - self.x_maxframes))
        start_frames = start_frames // 2 * 2
        y_start = start_frames // 2
        y_maxframes = self.x_maxframes // 2
        item['hubert'] = item['hubert'][start_frames: start_frames + self.x_maxframes]
        item['id'] = item['id'][y_start: y_start + y_maxframes]
        item['exp'] = item['exp'][y_start: y_start + y_maxframes]
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Return a DataLoader over this dataset that uses ``self.collater``."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Pad a list of samples into a batch.

        Samples that failed to load (``None`` from ``__getitem__``) are
        dropped; previously a single failed sample raised ``TypeError`` here.

        Returns:
            A dict with ``item_id``, ``hubert``, ``x_mask``, ``id``, ``exp``
            and ``y_mask``, or ``None`` when no usable sample remains.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return None
        x_len = max(s['hubert'].size(0) for s in samples)
        y_len = x_len // 2
        batch = {
            'item_id': [s['item_id'] for s in samples],
        }
        # Presumably collate_xd zero-pads along time up to max_len -- TODO confirm.
        batch['hubert'] = collate_xd([s["hubert"] for s in samples], max_len=x_len, pad_idx=0)  # [b, t_max_x, c]
        # A frame counts as valid when its feature vector is not all zeros.
        batch['x_mask'] = (batch['hubert'].abs().sum(dim=-1) > 0).float()  # [b, t_max_x]
        batch['id'] = collate_xd([s["id"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, c]
        batch['exp'] = collate_xd([s["exp"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, c]
        batch['y_mask'] = (batch['id'].abs().sum(dim=-1) > 0).float()  # [b, t_max_y]
        return batch


if __name__ == '__main__':
    # Smoke test: iterate once over the whole training split.
    os.environ["OMP_NUM_THREADS"] = "1"
    ds = SyncNet_Dataset("train", 'data/binary/th1kh')
    dl = ds.get_dataloader()
    for b in tqdm(dl):
        pass