from __future__ import annotations

import json

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class BaseDataset(Dataset):
    def __init__(self, data_root, anno_file, target_height=320, target_width=576, num_frames=25):
        self.data_root = data_root
        assert target_height % 64 == 0 and target_width % 64 == 0, \
            "target_height and target_width must be integer multiples of 64"
        # Map images from [0, 1] to [-1, 1].
        self.img_preprocessor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2.0 - 1.0)
        ])
        # anno_file may be a single JSON annotation file or a list of them.
        if isinstance(anno_file, list):
            self.samples = list()
            for each_file in anno_file:
                with open(each_file) as anno_json:
                    self.samples += json.load(anno_json)
        else:
            with open(anno_file) as anno_json:
                self.samples = json.load(anno_json)
        self.target_height = target_height
        self.target_width = target_width
        self.num_frames = num_frames
        # self.log_cond_aug_dist = torch.distributions.Normal(-3.0, 0.5)

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        ori_w, ori_h = image.size
        # Center-crop to the target aspect ratio, then resize to the target resolution.
        if ori_w / ori_h > self.target_width / self.target_height:
            tmp_w = int(self.target_width / self.target_height * ori_h)
            left = (ori_w - tmp_w) // 2
            right = (ori_w + tmp_w) // 2
            image = image.crop((left, 0, right, ori_h))
        elif ori_w / ori_h < self.target_width / self.target_height:
            tmp_h = int(self.target_height / self.target_width * ori_w)
            top = (ori_h - tmp_h) // 2
            bottom = (ori_h + tmp_h) // 2
            image = image.crop((0, top, ori_w, bottom))
        image = image.resize((self.target_width, self.target_height), resample=Image.LANCZOS)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = self.img_preprocessor(image)
        return image

    def get_image_path(self, sample_dict, current_index):
        # Subclasses must override this to map a sample and frame index to an image path.
        pass

    def build_data_dict(self, image_seq, sample_dict):
        # log_cond_aug = self.log_cond_aug_dist.sample()
        # cond_aug = torch.exp(log_cond_aug)
        cond_aug = torch.tensor([0.0])
        data_dict = {
            "img_seq": torch.stack(image_seq),
            "motion_bucket_id": torch.tensor([127]),
            "fps_id": torch.tensor([9]),
            "cond_frames_without_noise": image_seq[0],
            "cond_frames": image_seq[0] + cond_aug * torch.randn_like(image_seq[0]),
            "cond_aug": cond_aug
        }
        return data_dict

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_dict = self.samples[index]
        image_seq = list()
        for i in range(self.num_frames):
            current_index = i
            img_path = self.get_image_path(sample_dict, current_index)
            image = self.preprocess_image(img_path)
            image_seq.append(image)
        return self.build_data_dict(image_seq, sample_dict)
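

# --- Illustrative usage sketch (not part of the original class) ---------------
# A minimal example of how a subclass might implement get_image_path(). The
# annotation keys "img_dir" and "frame_files" below are assumptions made purely
# for illustration; the real annotation schema may differ.
import os


class ExampleFrameFolderDataset(BaseDataset):
    def get_image_path(self, sample_dict, current_index):
        # Assumed layout: each sample lists its frame filenames in temporal order.
        return os.path.join(self.data_root, sample_dict["img_dir"],
                            sample_dict["frame_files"][current_index])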