import json
import os, io, csv, math, random

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing
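

# Text-to-video dataset: each sample pairs a decoded video clip with its
# tokenized caption. Annotations are read from a JSON file whose entries
# carry a video 'path' and a list of captions under 'cap'.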
class T2V_dataset(Dataset):

    def __init__(self, args, transform, temporal_sample, tokenizer):
        # with open(args.data_path, 'r') as csvfile:
        #     self.samples = list(csv.DictReader(csvfile))
        self.video_folder = args.video_folder
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.model_max_length = args.model_max_length
        self.v_decoder = DecordInit()

        with open(args.data_path, 'r') as f:
            self.samples = json.load(f)

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()
    def __len__(self):
        return len(self.samples)
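
    # Returns (video, input_ids, cond_mask). Any failure while decoding or
    # tokenizing is logged and the call retries with a random index, so a
    # single corrupt sample cannot stop an epoch.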
    def __getitem__(self, idx):
        try:
            # video = torch.randn(3, 16, 128, 128)
            # input_ids = torch.ones(1, 120).to(torch.long).squeeze(0)
            # cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0)
            # return video, input_ids, cond_mask

            video_path = self.samples[idx]['path']
            video = self.decord_read(video_path)
            video = self.transform(video)  # T C H W -> T C H W
            video = video.transpose(0, 1)  # T C H W -> C T H W

            text = self.samples[idx]['cap'][0]
            text = text_preprocessing(text)
            text_tokens_and_mask = self.tokenizer(
                text,
                max_length=self.model_max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
            input_ids = text_tokens_and_mask['input_ids'].squeeze(0)
            cond_mask = text_tokens_and_mask['attention_mask'].squeeze(0)

            if self.use_image_num != 0 and self.use_img_from_vid:
                # Reuse evenly spaced frames of the clip as extra still images.
                assert self.num_frames >= self.use_image_num
                select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
                images = video[:, select_image_idx]  # c, num_img, h, w
                video = torch.cat([video, images], dim=1)  # c, num_frame+num_img, h, w
                input_ids = torch.stack([input_ids] * (1 + self.use_image_num))  # 1+self.use_image_num, l
                cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+self.use_image_num, l
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                images, captions = self.img_cap_list[idx]
                raise NotImplementedError
            else:
                pass

            return video, input_ids, cond_mask
        except Exception as e:
            print(f'Error with {e}, {self.samples[idx]}')
            # Fall back to a random sample so one bad file does not abort training.
            return self.__getitem__(random.randint(0, self.__len__() - 1))
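
    # Two interchangeable frame readers: tv_read uses torchvision's built-in
    # decoder, while decord_read (the one __getitem__ calls) uses the decord
    # library via DecordInit.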
    def tv_read(self, path):
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video = vframes[frame_indice]  # (T, C, H, W)

        return video
    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data
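
    # Loads a separate image/caption list for the use_image_num != 0,
    # use_img_from_vid == False branch; left unimplemented here.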
    def get_img_cap_list(self):
        raise NotImplementedError
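

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the dataset together, assuming: an `args`
# namespace exposing the fields read in __init__; a HuggingFace tokenizer,
# which provides the __call__ signature used above; and a temporal sampler
# mapping a frame count to a (start, end) window. The sampler below and all
# concrete paths/values are hypothetical stand-ins.
if __name__ == '__main__':
    from argparse import Namespace
    from transformers import AutoTokenizer

    def temporal_sample(total_frames, num_frames=16):
        # Hypothetical sampler: pick a random window of `num_frames` frames.
        start = random.randint(0, max(total_frames - num_frames, 0))
        return start, min(start + num_frames, total_frames)

    args = Namespace(
        data_path='annotations.json',  # JSON list of {'path': ..., 'cap': [...]}
        video_folder='videos/',
        num_frames=16,
        model_max_length=120,
        use_image_num=0,
        use_img_from_vid=False,
    )
    tokenizer = AutoTokenizer.from_pretrained('t5-base')
    transform = transforms.Lambda(lambda x: x / 255.)  # scale uint8 frames to [0, 1]
    dataset = T2V_dataset(args, transform, temporal_sample, tokenizer)
    video, input_ids, cond_mask = dataset[0]
    print(video.shape, input_ids.shape, cond_mask.shape)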