|
import os |
|
import json |
|
import sys |
|
import copy |
|
import math |
|
import torch |
|
import decord |
|
import random |
|
import numpy as np |
|
from PIL import Image |
|
from decord import VideoReader |
|
from torch.utils.data import Dataset |
|
from llava.utils import master_print |
|
from typing import Dict, Optional, Sequence, List |
|
from llava.datasets.data_cfgs import data_configs |
|
from transformers import CLIPImageProcessor, SiglipImageProcessor |
|
|
|
from llava.mm_utils import get_frame_indices, process_anyres_image |
|
from torch.utils.data.dataloader import default_collate |
|
|
|
# Ask decord to hand decoded frames back as torch tensors.
# VideoTaskDataset.vis_preprocess relies on this: it calls `.numpy()` on the
# batch returned by `VideoReader.get_batch`.
decord.bridge.set_bridge("torch")
|
|
|
class TaskBaseDataset(Dataset):
    """Base task dataset backed by a JSON annotation list.

    The base class only wires up configuration and annotation loading.
    ``vis_preprocess`` and ``text_preprocess`` are hooks that do nothing here
    (they return ``None``) and are meant to be overridden by subclasses.

    Args:
        anno_path: Path to a JSON annotation file. It is only read when the
            instance does not already carry an ``annotation`` attribute.
        data_args: Configuration object; ``image_aspect_ratio``,
            ``image_grid_pinpoints`` and ``image_processor`` are read from it.
        name: Dataset name, used in log messages only.
        **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(self, anno_path=None, data_args=None, name=None, **kwargs):
        self.anno_path = anno_path
        self.data_args = data_args
        self.image_aspect_ratio = data_args.image_aspect_ratio
        self.image_grid_pinpoints = data_args.image_grid_pinpoints
        self.vis_processor = data_args.image_processor
        self.type = None
        self.name = name

        master_print(f"Loading dataset {name}...")
        # Keep an annotation that was attached before this constructor ran;
        # only hit the disk when there is nothing loaded yet.
        if anno_path is not None and not hasattr(self, 'annotation'):
            # Context manager so the annotation file is closed right away
            # instead of lingering until garbage collection.
            with open(anno_path, 'r') as anno_file:
                self.annotation = json.load(anno_file)
        # Guarded: without an annotation there is no sample count to report,
        # and touching ``self.annotation`` would raise AttributeError.
        if hasattr(self, 'annotation'):
            master_print(f"Finish loading dataset {name} {len(self.annotation)} samples...")

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotation)

    def collater(self, samples):
        """Batch ``samples`` with torch's ``default_collate``."""
        return default_collate(samples)

    def text_preprocess(self, sources) -> List[List[Dict[str, str]]]:
        """Hook for turning an annotation item into conversations.

        The base implementation does nothing and returns ``None``.
        """
        pass

    def vis_preprocess(self, vis_path) -> Optional[List]:
        """Hook for loading the visual input found at ``vis_path``.

        The base implementation does nothing and returns ``None``.
        """
        pass

    @property
    def data_type(self):
        """Modality tag stored in ``self.type`` (``None`` in the base class)."""
        return self.type

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the sample for annotation entry ``i``.

        Returns:
            A dict with ``'images'`` (result of ``vis_preprocess``) and
            ``'conversations'`` (result of ``text_preprocess`` on the whole
            annotation item), plus ``'id'`` when the item carries one.

        Raises:
            KeyError: If the item has neither ``'vis_path'`` nor
                ``'video_path'``.
        """
        item = self.annotation[i]

        # 'vis_path' takes precedence; 'video_path' is the fallback key.
        vis_path = item['vis_path'] if 'vis_path' in item else item['video_path']

        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item),
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret
|
|
|
|
|
class ImageTaskDataset(TaskBaseDataset):
    """Task dataset whose visual input is a single image file.

    Args:
        anno_path: Path to the JSON annotation file (forwarded to the base
            class).
        data_args: Configuration object (forwarded to the base class).
        name: Dataset name used in log messages.
    """

    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        self.type = 'images'

    @staticmethod
    def expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas, centring it on the short axis.

        Args:
            pil_img: PIL image to pad.
            background_color: Fill colour for the added border.

        Returns:
            ``pil_img`` itself when it is already square, otherwise a new
            image whose side equals the longer edge of the input.
        """
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        # One of the two offsets is always 0: the image fills the long axis.
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def _run_processor(self, image):
        """Apply ``self.vis_processor`` to one image.

        CLIP / SigLIP processors are asked for PyTorch tensors and the first
        entry of ``'pixel_values'`` is returned; any other processor's
        ``preprocess`` result is returned untouched.
        """
        if isinstance(self.vis_processor, (CLIPImageProcessor, SiglipImageProcessor)):
            return self.vis_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        return self.vis_processor.preprocess(image)

    def preprocess_image(self, image):
        """Preprocess a PIL image according to ``self.image_aspect_ratio``.

        * ``'pad'``: pad to a square filled with the processor's mean colour,
          then run the processor.
        * ``'anyres'``: delegate to ``process_anyres_image``.
        * anything else: run the processor on the image as is.

        Returns:
            Whatever the selected path produces (the result of
            ``process_anyres_image`` or of the image processor).
        """
        if self.image_aspect_ratio == "anyres":
            return process_anyres_image(image, self.vis_processor, self.image_grid_pinpoints)

        if self.image_aspect_ratio == 'pad':
            # image_mean is scaled by 255 to get an integer fill colour.
            background = tuple(int(x * 255) for x in self.vis_processor.image_mean)
            image = self.expand2square(image, background)

        return self._run_processor(image)

    def vis_preprocess(self, vis_path):
        """Load the image at ``vis_path`` and preprocess it.

        Returns:
            A list of preprocessed images: the result of ``preprocess_image``
            if that is already a list, otherwise a one-element list.
        """
        # ``convert`` yields an in-memory image, so the file can be closed
        # immediately instead of staying open until garbage collection.
        with Image.open(vis_path) as raw_image:
            image = raw_image.convert('RGB')

        processed = self.preprocess_image(image)
        return processed if isinstance(processed, list) else [processed]
|
|
|
|
|
class VideoTaskDataset(ImageTaskDataset):
    """Task dataset that decodes frames from a video file with decord.

    ``num_segments`` and ``sample_strategy`` are read from ``data_args`` and
    forwarded to ``get_frame_indices`` when picking frames.
    """

    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)

        self.num_segments = self.data_args.num_segments
        self.sample_strategy = self.data_args.sample_strategy
        self.type = 'video'

    def vis_preprocess(self, vis_path):
        """Decode, sample and preprocess frames of the video at ``vis_path``.

        Best effort: any exception raised while reading or preprocessing is
        printed together with the path and ``None`` is returned.

        Returns:
            A list with one ``preprocess_image`` result per sampled frame, or
            ``None`` on failure.
        """
        try:
            reader = VideoReader(vis_path)
            total_frames = len(reader)
            avg_fps = reader.get_avg_fps()
            # Unused below, but deliberately kept: a zero average fps raises
            # here and is routed to the handler underneath.
            duration = total_frames / float(avg_fps)

            indices = get_frame_indices(
                self.num_segments,
                total_frames,
                sample=self.sample_strategy,
                input_fps=avg_fps,
                pad_last=False,
            )
            batch = reader.get_batch(indices)
            batch = batch.numpy().astype(np.uint8)

            rgb_frames = [Image.fromarray(array).convert('RGB') for array in batch]
            return [self.preprocess_image(frame) for frame in rgb_frames]
        except Exception as e:
            print(e, vis_path, flush=True)
            return None
|
|
|
|
|
class FramesTaskDataset(ImageTaskDataset):
    """Task dataset whose visual input is a folder of pre-extracted frames.

    Args:
        anno_path: Path to the JSON annotation file (forwarded to the base
            class).
        data_args: Configuration object; ``num_segments`` is read from it.
        fps: Target frame rate. When it is below ``default_fps`` (2.0) the
            frame list is thinned by ``floor(default_fps / fps)``. Must be
            positive: a value of 0 makes that division fail.
            NOTE(review): presumably ``default_fps`` is the rate the frames
            were extracted at -- confirm against the extraction pipeline.
        name: Dataset name used in log messages.
    """

    def __init__(self, anno_path=None, data_args=None, fps=0.5, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

        self.num_segments = self.data_args.num_segments

        self.type = 'video'
        self.default_fps = 2.0
        self.fps = fps

    @staticmethod
    def _downsample_frames(frames, interval, keep_first_last=True):
        """Keep every ``interval``-th frame.

        Args:
            frames: Sequence of frames, in order.
            interval: Sampling step; the first kept element is at offset
                ``interval - 1``.
            keep_first_last: When true, the first and last frames are always
                kept and only the frames between them are thinned.

        Returns:
            The thinned frames.
        """
        if not keep_first_last:
            return frames[interval - 1::interval]

        # With at most two frames there is nothing between the end points.
        # Returning early also avoids emitting a lone frame twice (as both
        # "first" and "last") and indexing into an empty sequence.
        if len(frames) <= 2:
            return list(frames)

        middle = frames[1:-1][interval - 1::interval]
        return [frames[0], *middle, frames[-1]]

    @staticmethod
    def _sample_frames(frames, num_segments):
        """Pick one frame from each of ``num_segments`` contiguous segments.

        The index range ``[0, len(frames))`` is cut into ``num_segments``
        segments with ``np.linspace``. One index is drawn at random from every
        segment; if any segment offers no candidate, the first index of every
        segment is used instead.

        Returns:
            A list of ``num_segments`` frames in their original order.
        """
        boundaries = np.linspace(start=0, stop=len(frames), num=num_segments + 1).astype(int)
        # (first index, last index) of each segment.
        segments = [(int(lo), int(hi) - 1) for lo, hi in zip(boundaries[:-1], boundaries[1:])]

        try:
            # NOTE(review): ``range(lo, hi)`` excludes ``hi``, so the last
            # frame of a segment is never drawn and a one-frame segment
            # triggers the fallback for all segments. Left unchanged so that
            # seeded sampling stays reproducible.
            picks = [random.choice(range(lo, hi)) for lo, hi in segments]
        except IndexError:
            # random.choice raises IndexError on an empty range.
            picks = [lo for lo, _ in segments]

        return [frames[index] for index in picks]

    def vis_preprocess(self, vis_path):
        """Load, subsample and preprocess the frames stored in ``vis_path``.

        Files whose name starts with ``'cuttime'`` are ignored. Frames are
        ordered by the integer in their file name: for ``jpeg`` files the part
        after the last ``'_'`` of the stem, otherwise the whole stem. The
        naming scheme is decided from the first listed file.

        Frames that cannot be opened are reported on stdout and skipped.

        Returns:
            A flat list of preprocessed frames; list results of
            ``preprocess_image`` are flattened into it.

        Raises:
            IndexError: If the folder contains no frame files.
            ValueError: If a frame file name does not contain the expected
                integer.
        """
        entries = [
            (os.path.splitext(file_name)[0], file_name)
            for file_name in os.listdir(vis_path)
            if not file_name.startswith('cuttime')
        ]
        if not entries:
            # Same exception type as indexing the empty list would give,
            # but with a message that points at the offending folder.
            raise IndexError(f"No frame images found in {vis_path}")

        if entries[0][1].endswith('jpeg'):
            image_files = [(int(stem.split('_')[-1]), file_name) for stem, file_name in entries]
        else:
            image_files = [(int(stem), file_name) for stem, file_name in entries]

        image_files.sort(key=lambda entry: entry[0])

        if self.fps < self.default_fps:
            interval = math.floor(self.default_fps / self.fps)
            image_files = self._downsample_frames(image_files, interval, keep_first_last=True)

        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for _, file_name in image_files:
            frame_path = os.path.join(vis_path, file_name)
            try:
                # ``convert`` yields an in-memory image, so the file handle
                # can be released as soon as the frame is decoded.
                with Image.open(frame_path) as raw_frame:
                    images.append(raw_frame.convert('RGB'))
            except Exception as e:
                # Best effort: skip the frame, but say so instead of
                # dropping it silently.
                print(e, frame_path)
                sys.stdout.flush()

        formatted_images = []
        for image in images:
            processed = self.preprocess_image(image)
            if isinstance(processed, list):
                formatted_images.extend(processed)
            else:
                formatted_images.append(processed)
        return formatted_images
|
|
|
|
|
|
|
|