# Provenance (captured from the hosting web page alongside this file):
# fffiloni's picture
# Migrated from GitHub
# d59f323 verified
import logging
import os
from typing import Literal
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import build_origin_dataset
import copy
from .encode_fn import video_lisa_encode_fn
import json
import cv2
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from decord import VideoReader, cpu
def _get_rawvideo_dec(video_path, select_frames=5):
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
elif os.path.exists(video_path.replace('mkv', 'mp4')):
vreader = VideoReader(video_path.replace('mkv', 'mp4'), ctx=cpu(0))
else:
print(video_path)
raise FileNotFoundError
fps = vreader.get_avg_fps()
f_start = 0
f_end = len(vreader) - 1
num_frames = f_end - f_start + 1
assert num_frames > 0, f'num_frames: {num_frames}, f_start: {f_start}, f_end: {f_end}, fps: {fps}, video_path: {video_path}'
# T x 3 x H x W
if num_frames <= select_frames:
sample_pos = range(f_start, f_end + 1)
else:
split_point = np.linspace(0, num_frames, num=select_frames+1, dtype=int)
sample_pos = [np.random.randint(split_point[i], split_point[i+1]) for i in range(select_frames)]
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
return patch_images
class VideoChatUniViDataset(Dataset):
    """Video instruction-tuning dataset built from a JSON list of conversations.

    Each JSON record must provide ``'video'`` (a path relative to
    ``image_folder``) and ``'conversations'`` (a list of turns with ``'from'``
    and ``'value'`` keys; ``'from' == 'human'`` marks a question, any other
    value an answer). For every item, frames are sampled from the video,
    image placeholder tokens are prepended to the first question, and the
    text is templated and tokenized.
    """

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    # Fallback statistics used by ``visualization_debug`` to undo
    # normalization when an external preprocessor does not expose
    # ``image_mean`` / ``image_std`` (the values the debug dump used before).
    CLIP_MEAN = (0.4814, 0.4578, 0.4082)
    CLIP_STD = (0.2686, 0.2613, 0.2757)
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'
    FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
    FAST_IMG_START_TOKEN = '<fast_img>'
    FAST_IMG_END_TOKEN = '</fast_img>'

    def __init__(self,
                 image_folder,
                 json_file,
                 extra_image_processor=None,
                 tokenizer=None,
                 sampled_frames=10,
                 offline_processed_text_folder=None,
                 template_map_fn=None,
                 max_length=2048,
                 lazy=True,
                 repeats=1,
                 special_tokens=None,
                 use_fast=False,
                 n_fast_images=50,
                 fast_pool_size=4,
                 arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
                 preprocessor=None,
                 ):
        """Build the dataset.

        Args:
            image_folder: Root directory that ``'video'`` paths are joined to.
            json_file: JSON file holding the list of annotation records.
            extra_image_processor: Optional config built via ``BUILDER``.
            tokenizer: Tokenizer config built via ``BUILDER``.
            sampled_frames: Frames sampled per video for the main branch.
            offline_processed_text_folder: Not supported; raises
                ``NotImplementedError`` when given.
            template_map_fn: Callable, or a config dict whose ``'type'`` is
                called with the remaining keys. The dict is not modified.
            max_length: Maximum token length passed to the encode function.
            lazy: Must be ``True``.
            repeats: Virtual length multiplier (see ``__len__``).
            special_tokens: Extra tokens added to the tokenizer.
            use_fast: Also sample ``n_fast_images`` frames for a fast branch.
            n_fast_images: Frames sampled for the fast branch.
            fast_pool_size: Per-side pooling size; each fast frame
                contributes ``fast_pool_size ** 2`` fast context tokens.
            arch_type: ``'intern_vl'``, ``'qwen'`` or ``'llava'``; selects
                the placeholder tokens, image size and token budget.
            preprocessor: Optional processor config built via ``BUILDER``;
                when None a built-in resize + ImageNet-normalize transform
                is used.
        """
        assert lazy is True
        assert offline_processed_text_folder or (json_file and tokenizer)
        self.tokenizer = BUILDER.build(tokenizer)
        self.sampled_frames = sampled_frames
        self.lazy = lazy
        self.max_length = max_length

        # Instantiate the template function from its config WITHOUT mutating
        # the caller's dict: configs are commonly shared between datasets,
        # and deleting 'type' in place broke every later construction.
        if isinstance(template_map_fn, dict) and self.lazy:
            fn_type = template_map_fn['type']
            fn_kwargs = {k: v for k, v in template_map_fn.items() if k != 'type'}
            template_map_fn = fn_type(**fn_kwargs)
        self.template_map_fn = template_map_fn

        if offline_processed_text_folder and json_file:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`json_file` are set, and we load dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
        if offline_processed_text_folder is not None:
            raise NotImplementedError
        json_datas = self.json_file_preprocess(json_file)
        self.json_datas = json_datas
        json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
        # ``lazy`` is asserted True above, so only the lazy path exists.
        self.text_data = build_origin_dataset(json_data, 'train')

        self.image_folder = image_folder
        self.extra_image_processor = (
            BUILDER.build(extra_image_processor)
            if extra_image_processor is not None else None)

        self.arch_type = arch_type
        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''

        self.repeats = repeats
        self._system = ''
        self.downsample_ratio = 1 if self.arch_type == 'llava' else 0.5
        self.image_size = 336 if self.arch_type == 'llava' else 448
        patch_size = 14
        # Context tokens per frame: (image_size // patch_size)^2 patches,
        # scaled by downsample_ratio^2.
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
        if self.arch_type == 'qwen':
            # One placeholder per frame; __getitem__ expands it once the
            # processor reports image_grid_thw.
            self.patch_token = 1

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)

        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.use_fast = use_fast
        self.n_fast_images = n_fast_images
        self.fast_pool_size = fast_pool_size

        # for visualization debug
        self.save_folder = './work_dirs/video_debug/'
        self.cur_number = 0

        print("Video Chat dataset, include {} items.".format(len(self.text_data)))

    def __len__(self):
        """Virtual length: every record is visited ``repeats`` times."""
        return len(self.text_data) * self.repeats

    @property
    def modality_length(self):
        """Per-sample length hints; every sample reports the constant 10000."""
        # No need to iterate (and thereby decode) every record of the HF
        # dataset just to produce a constant list.
        return [10000] * len(self.text_data)

    def real_len(self):
        """Number of distinct records (ignores ``repeats``)."""
        return len(self.text_data)

    def json_file_preprocess(self, json_file):
        """Load and return the annotation list stored in ``json_file``."""
        with open(json_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def dataset_map_fn(self, data_dict, select_k=5):
        """Decode frames for one record and build its templated conversation.

        Returns a dict with ``'images'`` (PIL frames), ``'conversation'`` and
        ``'fast_images'`` (PIL frames, or None when ``use_fast`` is off).
        """
        assert 'video' in data_dict
        video_file = os.path.join(self.image_folder, data_dict['video'])
        images = _get_rawvideo_dec(video_file, select_frames=select_k)
        fast_images = (_get_rawvideo_dec(video_file, select_frames=self.n_fast_images)
                       if self.use_fast else None)
        # Emit one frame-token group per *decoded* frame. Short videos return
        # fewer than ``select_k`` frames; using ``select_k`` here left the
        # text with more placeholders than there are pixel_values.
        text_dict = self.prepare_text(
            len(images), data_dict['conversations'], num_image_tokens=self.patch_token,
            n_fast_images=len(fast_images) if fast_images is not None else 0,
        )
        return {'images': images, 'conversation': text_dict['conversation'], 'fast_images': fast_images}

    def prepare_text(self, n_frames, conversation, num_image_tokens=256, n_fast_images=0):
        """Pair questions with answers and prepend visual placeholders.

        The first question is prefixed by the fast-frame block (when
        ``use_fast``) and by ``n_frames`` newline-separated frame blocks of
        ``num_image_tokens`` context tokens each. ``'<image>'`` markers are
        removed from questions. The first returned turn carries the system
        prompt.

        Returns:
            ``{'conversation': [{'input': str, 'output': str, ...}, ...]}``
        """
        if self.use_fast:
            n_fast_tokens = n_fast_images * self.fast_pool_size * self.fast_pool_size
            fast_frame_token_str = (f'{self.FAST_IMG_START_TOKEN}'
                                    f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_tokens}'
                                    f'{self.FAST_IMG_END_TOKEN}\n')
        else:
            fast_frame_token_str = ''
        frame_token_str = f'{self.IMG_START_TOKEN}{self.IMG_CONTEXT_TOKEN * num_image_tokens}{self.IMG_END_TOKEN}'

        questions = [turn['value'].replace('<image>', '')
                     for turn in conversation if turn['from'] == 'human']
        answers = [turn['value'] for turn in conversation if turn['from'] != 'human']
        assert len(questions) == len(answers)

        visual_prefix = fast_frame_token_str + ((frame_token_str + '\n') * n_frames).strip()
        turns = [
            {'input': (visual_prefix + question) if i == 0 else question, 'output': answer}
            for i, (question, answer) in enumerate(zip(questions, answers))
        ]
        # add system information
        turns[0].update({'system': self._system})
        return {'conversation': turns}

    def __getitem__(self, index):
        """Load, preprocess and tokenize one sample (``index`` wraps modulo ``real_len``)."""
        index = index % self.real_len()
        selected_data_dict = copy.deepcopy(self.text_data[index])
        data_dict = self.dataset_map_fn(selected_data_dict, select_k=self.sampled_frames)
        assert 'images' in data_dict
        if self.use_fast:
            assert 'fast_images' in data_dict

        num_video_tokens = None
        num_frame_tokens = None
        if data_dict.get('images', None) is not None:
            frames = [frame.convert('RGB') for frame in data_dict['images']]
            if self.preprocessor is None:
                # Built-in transform, stacked to (n_f, 3, h, w).
                data_dict['pixel_values'] = torch.stack(
                    [self.transformer(frame) for frame in frames], dim=0)
            elif self.arch_type == 'qwen':
                processed = self.preprocessor(frames, do_resize=True, size=(self.image_size, self.image_size))
                processed['pixel_values'] = torch.tensor(processed['pixel_values'], dtype=torch.float)
                processed['image_grid_thw'] = torch.tensor(processed['image_grid_thw'], dtype=torch.int)
                num_frame_tokens = int(processed['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                num_frames = processed['image_grid_thw'].shape[0]
                num_video_tokens = num_frame_tokens * num_frames
                data_dict.update(processed)
            elif self.arch_type == 'llava':
                processed = self.preprocessor(frames, do_resize=True,
                                              size=(self.image_size, self.image_size))
                processed['pixel_values'] = np.stack(processed['pixel_values'], axis=0)
                processed['pixel_values'] = torch.tensor(processed['pixel_values'], dtype=torch.float)
                data_dict.update(processed)
            else:
                raise NotImplementedError
        else:
            data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
        data_dict['masks'] = None

        if num_video_tokens is not None:
            # qwen: expand the single per-frame placeholder emitted by
            # prepare_text (patch_token == 1) to the real per-frame count.
            assert self.patch_token == 1
            input_str = data_dict['conversation'][0]['input']
            input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
            assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
            data_dict['conversation'][0]['input'] = input_str

        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer,
                                              max_length=self.max_length, with_image_token=True))

        # for fast branch
        if self.use_fast:
            if self.transformer is None:
                raise NotImplementedError(
                    'use_fast=True requires preprocessor=None (the built-in transform).')
            data_dict['fast_pixel_values'] = torch.stack(
                [self.transformer(frame.convert('RGB')) for frame in data_dict['fast_images']],
                dim=0)  # (n_f, 3, h, w)

        # Debug hook: uncomment to dump frames and decoded text.
        # self.visualization_debug(data_dict)
        data_dict['type'] = 'video'
        return data_dict

    def _to_display_image(self, pixel_value):
        """Convert one normalized ``(3, H, W)`` frame to an ``(H, W, 3)`` uint8 BGR array.

        Works on a copy, so the caller's tensor is left unchanged.
        """
        if self.preprocessor is None:
            mean, std = self.IMAGENET_MEAN, self.IMAGENET_STD
        else:
            mean = getattr(self.preprocessor, 'image_mean', None)
            std = getattr(self.preprocessor, 'image_std', None)
            if mean is None:
                mean = self.CLIP_MEAN
            if std is None:
                std = self.CLIP_STD
        image = pixel_value.detach().to('cpu', torch.float32)
        mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
        # Clamp before the uint8 cast so out-of-range values do not wrap.
        image = (image * std_t + mean_t).clamp(0.0, 1.0) * 255
        # cv2.imwrite expects BGR channel order; the frames are RGB.
        image = image.flip(0).permute(1, 2, 0).contiguous()
        return image.to(torch.uint8).numpy()

    def visualization_debug(self, data_dict):
        """Dump a sample's frames (as JPEGs) and decoded input text under ``save_folder``."""
        save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
        save_folder_image = os.path.join(save_folder, 'image')
        os.makedirs(save_folder_image, exist_ok=True)
        self.cur_number += 1

        # images -- only meaningful for stacked (n_f, 3, h, w) frame tensors.
        pixel_values = data_dict['pixel_values']
        if isinstance(pixel_values, torch.Tensor) and pixel_values.dim() == 4:
            for i_image, image_pixel_value in enumerate(pixel_values):
                cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)),
                            self._to_display_image(image_pixel_value))

        # text
        input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
        with open(os.path.join(save_folder, 'text.json'), 'w', encoding='utf-8') as f:
            json.dump([input_text], f)