# model1/llava/datasets/base_dataset.py
import os
import json
import sys
import math
import torch
import decord
import random
import numpy as np
from PIL import Image
from decord import VideoReader
from torch.utils.data import Dataset
from llava.utils import master_print
from typing import Dict, List
from llava.datasets.data_cfgs import data_configs
from transformers import CLIPImageProcessor, SiglipImageProcessor
from llava.mm_utils import get_frame_indices, process_anyres_image
from torch.utils.data.dataloader import default_collate
decord.bridge.set_bridge("torch")
class TaskBaseDataset(Dataset):
""" Implementation of base task dataset """
def __init__(self, anno_path=None, data_args=None, name=None, **kwargs):
self.anno_path = anno_path
self.data_args = data_args
self.image_aspect_ratio = data_args.image_aspect_ratio
self.image_grid_pinpoints = data_args.image_grid_pinpoints
self.vis_processor = data_args.image_processor
self.type = None
self.name = name
master_print(f"Loading dataset {name}...")
if (anno_path is not None):
if not hasattr(self, 'annotation'):
self.annotation = json.load(open(anno_path, 'r'))
master_print(f"Finish loading dataset {name} {len(self.annotation)} samples...")
def __len__(self):
return len(self.annotation)
def collater(self, samples):
return default_collate(samples)
    def text_preprocess(self, sources) -> List[List[Dict[str, str]]]:
        raise NotImplementedError

    def vis_preprocess(self, vis_path) -> List[Image.Image]:
        raise NotImplementedError
@property
def data_type(self):
return self.type
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
item = self.annotation[i]
vis_path = item['vis_path'] if 'vis_path' in item else item['video_path']
ret = {
'images': self.vis_preprocess(vis_path),
'conversations': self.text_preprocess(item)
}
if 'id' in item:
ret['id'] = item['id']
return ret
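
# Hedged sketch (not in the original file): a concrete task dataset only needs
# to implement text_preprocess / vis_preprocess. The names and conversation
# format below are hypothetical illustrations, not a verified schema.
#
#     class MyCaptionDataset(TaskBaseDataset):
#         def text_preprocess(self, sources):
#             return [[{'from': 'human', 'value': sources['question']},
#                      {'from': 'gpt', 'value': sources['answer']}]]
#
#         def vis_preprocess(self, vis_path):
#             return [Image.open(vis_path).convert('RGB')]
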
class ImageTaskDataset(TaskBaseDataset):
def __init__(self, anno_path=None, data_args=None, name=None):
super().__init__(anno_path=anno_path,
data_args=data_args,
name=name)
self.type = 'images'
@staticmethod
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
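    # Worked example: a 640x480 input hits the width > height branch and is
    # pasted into a 640x640 canvas at (0, (640 - 480) // 2) = (0, 80), i.e.
    # centered vertically between two strips of the processor's mean color.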
def preprocess_image(self, image):
        if self.image_aspect_ratio == 'pad':
            image = self.expand2square(image, tuple(int(x * 255) for x in self.vis_processor.image_mean))
            if isinstance(self.vis_processor, (CLIPImageProcessor, SiglipImageProcessor)):
                image = self.vis_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            else:
                image = self.vis_processor.preprocess(image)
        elif self.image_aspect_ratio == "anyres":
            image = process_anyres_image(image, self.vis_processor, self.image_grid_pinpoints)
        else:
            if isinstance(self.vis_processor, (CLIPImageProcessor, SiglipImageProcessor)):
                image = self.vis_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            else:
                image = self.vis_processor.preprocess(image)
return image
def vis_preprocess(self, vis_path):
image = Image.open(vis_path).convert('RGB')
image = self.preprocess_image(image)
if isinstance(image, list):
images = image
else:
images = [image]
return images
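
# Hedged usage sketch: the data_args fields are inferred from this file, the
# paths are placeholders, and a real subclass must still provide text_preprocess:
#
#     ds = MyImageDataset(anno_path='path/to/annos.json', data_args=data_args, name='my_images')
#     sample = ds[0]  # {'images': [Tensor(3, H, W), ...], 'conversations': [...], 'id': ...}
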
class VideoTaskDataset(ImageTaskDataset):
def __init__(self, anno_path=None, data_args=None, name=None):
super().__init__(anno_path=anno_path,
data_args=data_args,
name=name)
        # if num_segments is not specified, fall back to the default from data_args
self.num_segments = self.data_args.num_segments
self.sample_strategy = self.data_args.sample_strategy
self.type = 'video'
    def vis_preprocess(self, vis_path):
        images = None
        try:
            video_reader = VideoReader(vis_path)
            vlen = len(video_reader)
            fps = video_reader.get_avg_fps()
            frame_indices = get_frame_indices(self.num_segments, vlen,
                                              sample=self.sample_strategy, input_fps=fps, pad_last=False)
            # the decord torch bridge returns a (T, H, W, C) tensor
            frames = video_reader.get_batch(frame_indices)
            frames = frames.numpy().astype(np.uint8)
            images = [Image.fromarray(frame).convert('RGB') for frame in frames]
            images = [self.preprocess_image(image) for image in images]
        except Exception as e:
            print(f"Failed to load video {vis_path}: {e}")
            sys.stdout.flush()
            images = None
        return images
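    # Rough arithmetic: a 10 s clip at 30 fps has vlen = 300, so with
    # num_segments = 8 a uniform strategy lands near every 300 / 8 = 37.5th
    # frame; exact indices depend on get_frame_indices and sample_strategy.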
class FramesTaskDataset(ImageTaskDataset):
def __init__(self, anno_path=None, data_args=None, fps=0.5, name=None):
super().__init__(anno_path=anno_path,
data_args=data_args,
name=name)
        # if num_segments is not specified, fall back to the default from data_args
        self.num_segments = self.data_args.num_segments
self.type = 'video'
self.default_fps = 2.0
self.fps = fps
@staticmethod
def _downsample_frames(frames, interval, keep_first_last=True):
if keep_first_last:
first, last, mid = frames[0], frames[-1], frames[1:-1]
sampled_frames = mid[interval - 1::interval]
ret = [first] + sampled_frames + [last]
else:
            # may return an empty list; keeping the first and last frames is recommended
ret = frames[interval - 1::interval]
return ret
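    # Worked example: frames = [f0..f9], interval = 4, keep_first_last=True:
    # mid = [f1..f8], mid[3::4] = [f4, f8], so the result is [f0, f4, f8, f9].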
@staticmethod
    def _sample_frames(frames, num_segments):
        frame_indices = list(range(len(frames)))
        # split the frame indices into num_segments contiguous, near-equal bins
        intervals = np.linspace(start=0, stop=len(frame_indices), num=num_segments + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        try:
            # range(start, end) excludes end, so the last index of a bin is never drawn
            frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
        except IndexError:
            # an empty bin (more segments than frames); fall back to the bin starts
            frame_indices = [x[0] for x in ranges]
        sampled_frames = [frames[idx] for idx in frame_indices]
        return sampled_frames
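    # Worked example: 10 frames, num_segments = 4 gives bin edges
    # [0, 2, 5, 7, 10] -> ranges [(0, 1), (2, 4), (5, 6), (7, 9)], from which
    # one index per bin is drawn, e.g. [0, 3, 5, 8].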
    def vis_preprocess(self, vis_path):
        # collect (sort_key, filename) pairs, skipping 'cuttime*' metadata files
        image_files = [(os.path.splitext(img)[0], img) for img in os.listdir(vis_path) if not img.startswith('cuttime')]
        if image_files[0][1].endswith('jpeg'):
            # gpt4v public data: the frame index is the last '_'-separated token of the stem
            image_files = [(int(x[0].split('_')[-1]), x[1]) for x in image_files]
        else:
            image_files = [(int(x[0]), x[1]) for x in image_files]
        image_files = sorted(image_files, key=lambda img: img[0])
if self.fps < self.default_fps:
interval = math.floor(self.default_fps / self.fps)
image_files = self._downsample_frames(image_files, interval, keep_first_last=True)
if self.num_segments > 0 and len(image_files) > self.num_segments:
image_files = self._sample_frames(image_files, self.num_segments)
images = []
for image_file in image_files:
try:
images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                # skip frames that fail to open or decode
                continue
formatted_images = []
for image in images:
im = self.preprocess_image(image)
if isinstance(im, list):
formatted_images.extend(im)
else:
formatted_images.append(im)
return formatted_images
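

# Hedged demo, runnable only with real files in place: the data_args fields
# below mirror the attributes this file reads; the annotation path and the
# CLIP checkpoint choice are hypothetical placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    data_args = SimpleNamespace(
        image_aspect_ratio='pad',
        image_grid_pinpoints=None,
        image_processor=CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14'),
        num_segments=8,
    )
    ds = FramesTaskDataset(anno_path='path/to/annos.json',  # hypothetical path
                           data_args=data_args, fps=0.5, name='frames-demo')
    print(f"{len(ds)} samples")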