import os
import json
import sys
import copy
import math
import torch
import decord
import random
import numpy as np
from PIL import Image
from decord import VideoReader
from torch.utils.data import Dataset
from llava.utils import master_print
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from transformers import CLIPImageProcessor, SiglipImageProcessor
from llava.mm_utils import get_frame_indices, process_anyres_image
from torch.utils.data.dataloader import default_collate

# Make decord's VideoReader.get_batch return torch tensors.
decord.bridge.set_bridge("torch")


class TaskBaseDataset(Dataset):
    """Base task dataset backed by a JSON annotation file.

    Subclasses set ``self.type`` and implement :meth:`text_preprocess` and
    :meth:`vis_preprocess`; :meth:`__getitem__` combines both for one
    annotation entry.
    """

    def __init__(self, anno_path=None, data_args=None, name=None, **kwargs):
        """
        Args:
            anno_path: Path to a JSON annotation file, indexed per sample by
                ``__getitem__``. If None, nothing is loaded here.
            data_args: Config object; ``image_aspect_ratio``,
                ``image_grid_pinpoints`` and ``image_processor`` are read
                from it.
            name: Dataset name, used in log messages.
            **kwargs: Accepted and ignored.
        """
        self.anno_path = anno_path
        self.data_args = data_args
        self.image_aspect_ratio = data_args.image_aspect_ratio
        self.image_grid_pinpoints = data_args.image_grid_pinpoints
        self.vis_processor = data_args.image_processor
        self.type = None
        self.name = name
        master_print(f"Loading dataset {name}...")
        if anno_path is not None:
            # Skip loading when self.annotation already exists (e.g. set by a
            # subclass before calling this initializer).
            if not hasattr(self, 'annotation'):
                # Close the file deterministically; JSON is UTF-8 by spec.
                with open(anno_path, 'r', encoding='utf-8') as f:
                    self.annotation = json.load(f)
            master_print(f"Finish loading dataset {name} {len(self.annotation)} samples...")

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        """Collate a list of samples with torch's ``default_collate``."""
        return default_collate(samples)

    def text_preprocess(self, sources) -> List[List[Dict[str, str]]]:
        """Build the conversation for one annotation; overridden by subclasses."""
        pass

    def vis_preprocess(self, vis_path) -> Image:
        """Load and preprocess the visual input; overridden by subclasses."""
        pass

    @property
    def data_type(self):
        """The dataset's modality tag (``self.type``)."""
        return self.type

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return ``{'images', 'conversations'[, 'id']}`` for annotation ``i``."""
        item = self.annotation[i]
        # The visual input is stored under 'vis_path' or, failing that, 'video_path'.
        vis_path = item['vis_path'] if 'vis_path' in item else item['video_path']
        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item),
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret


class ImageTaskDataset(TaskBaseDataset):
    """Task dataset whose visual input is a single image file."""

    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)
        self.type = 'images'

    @staticmethod
    def expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square of side max(width, height), centered.

        Args:
            pil_img: PIL image to pad.
            background_color: Fill color for the padded area.

        Returns:
            ``pil_img`` itself if already square, otherwise a new padded image.
        """
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def _apply_processor(self, image):
        """Run ``self.vis_processor`` on one image."""
        # HF CLIP/SigLIP processors return a batched dict; take the single
        # pixel tensor. Any other processor is called with the image only.
        if isinstance(self.vis_processor, (CLIPImageProcessor, SiglipImageProcessor)):
            return self.vis_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        return self.vis_processor.preprocess(image)

    def preprocess_image(self, image):
        """Convert one PIL image to model input per ``image_aspect_ratio``.

        * ``'anyres'``: delegate to ``process_anyres_image``.
        * ``'pad'``: pad to a square, then run the processor.
        * anything else: run the processor directly.
        """
        if self.image_aspect_ratio == "anyres":
            return process_anyres_image(image, self.vis_processor, self.image_grid_pinpoints)
        if self.image_aspect_ratio == 'pad':
            # Pad with the processor's image_mean scaled by 255
            # (assumes image_mean is given in [0, 1]).
            fill = tuple(int(x * 255) for x in self.vis_processor.image_mean)
            image = self.expand2square(image, fill)
        return self._apply_processor(image)

    def vis_preprocess(self, vis_path):
        """Load the image at ``vis_path`` as RGB and return it preprocessed.

        Returns:
            A list: the result of ``preprocess_image`` if it is already a
            list, otherwise that result wrapped in a one-element list.
        """
        with Image.open(vis_path) as raw:
            image = raw.convert('RGB')
        image = self.preprocess_image(image)
        return image if isinstance(image, list) else [image]


class VideoTaskDataset(ImageTaskDataset):
    """Task dataset whose visual input is a video file decoded with decord."""

    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)
        # Frame-sampling settings passed to get_frame_indices.
        self.num_segments = self.data_args.num_segments
        self.sample_strategy = self.data_args.sample_strategy
        self.type = 'video'

    def vis_preprocess(self, vis_path):
        """Decode sampled frames from the video and preprocess each one.

        Returns:
            A list with one ``preprocess_image`` result per sampled frame, or
            None if any step raised (the exception and path are printed).
        """
        try:
            video_reader = VideoReader(vis_path)
            vlen = len(video_reader)
            fps = video_reader.get_avg_fps()
            frame_indices = get_frame_indices(self.num_segments, vlen,
                                              sample=self.sample_strategy,
                                              input_fps=fps, pad_last=False)
            frames = video_reader.get_batch(frame_indices).numpy().astype(np.uint8)
            images = [Image.fromarray(frame).convert('RGB') for frame in frames]
            return [self.preprocess_image(image) for image in images]
        except Exception as e:
            # Best effort: report the failure and return None instead of raising.
            print(e, vis_path)
            sys.stdout.flush()
            return None


class FramesTaskDataset(ImageTaskDataset):
    """Task dataset whose visual input is a directory of extracted frames.

    Frame files are ordered by the integer frame number in their names,
    downsampled from ``default_fps`` to ``fps`` when ``fps`` is lower, then,
    if ``num_segments > 0``, reduced to ``num_segments`` frames.
    """

    def __init__(self, anno_path=None, data_args=None, fps=0.5, name=None):
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)
        self.num_segments = self.data_args.num_segments
        self.type = 'video'
        # Presumably the rate at which frames were extracted to disk — TODO confirm.
        self.default_fps = 2.0
        # Target rate after downsampling.
        self.fps = fps

    @staticmethod
    def _downsample_frames(frames, interval, keep_first_last=True):
        """Keep every ``interval``-th frame of the list ``frames``.

        With ``keep_first_last`` the first and last frames are always kept
        and the stride is applied to the frames strictly between them.
        Without it the stride is applied to the whole list, which may yield
        an empty list.
        """
        if keep_first_last:
            # With fewer than two frames, first and last coincide; return the
            # frames unchanged instead of duplicating (or indexing an empty list).
            if len(frames) < 2:
                return list(frames)
            first, last, mid = frames[0], frames[-1], frames[1:-1]
            return [first] + mid[interval - 1::interval] + [last]
        return frames[interval - 1::interval]

    @staticmethod
    def _sample_frames(frames, num_segments):
        """Pick one random frame from each of ``num_segments`` contiguous spans.

        The index range ``[0, len(frames))`` is split at evenly spaced
        (floored) boundaries. Each non-empty span contributes a uniformly
        random frame, including its last one. An empty span, which only occurs
        when ``len(frames) < num_segments``, contributes its start index.
        """
        bounds = np.linspace(start=0, stop=len(frames), num=num_segments + 1).astype(int).tolist()
        indices = [
            # randrange's stop is exclusive, so every frame in [start, end) is eligible.
            random.randrange(start, end) if end > start else start
            for start, end in zip(bounds[:-1], bounds[1:])
        ]
        return [frames[i] for i in indices]

    def vis_preprocess(self, vis_path):
        """Load, order, subsample and preprocess the frames in ``vis_path``.

        Entries whose names start with 'cuttime' are ignored. File stems must
        encode an integer frame number. If the first listed file ends with
        'jpeg', the number after the last '_' is used; otherwise the whole
        stem is used.

        Returns:
            A flat list of preprocessed frames. It is empty when the directory
            has no frame files. Frames that fail to open are skipped.
        """
        image_files = [(os.path.splitext(img)[0], img)
                       for img in os.listdir(vis_path) if not img.startswith('cuttime')]
        if not image_files:
            return []
        # NOTE(review): naming scheme is decided from whichever file
        # os.listdir returns first.
        if image_files[0][1].endswith('jpeg'):
            # gpt4v public data: frame number is the last '_'-separated field.
            image_files = [(int(stem.split('_')[-1]), fname) for stem, fname in image_files]
        else:
            image_files = [(int(stem), fname) for stem, fname in image_files]
        image_files.sort(key=lambda img: img[0])

        if self.fps < self.default_fps:
            interval = math.floor(self.default_fps / self.fps)
            image_files = self._downsample_frames(image_files, interval, keep_first_last=True)
        if 0 < self.num_segments < len(image_files):
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for _, fname in image_files:
            try:
                with Image.open(os.path.join(vis_path, fname)) as raw:
                    images.append(raw.convert('RGB'))
            except Exception:
                # Best effort: an unreadable frame is dropped, not fatal.
                continue

        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            # preprocess_image may return a list; flatten into one list.
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images