OMG-LLaVA / xtuner /engine /hooks /dataset_info_hook.py
zhangtao-whu's picture
Upload folder using huggingface_hub
476ac07 verified
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
def split_list(lst, value):
    """Split ``lst`` into sub-lists at every occurrence of ``value``.

    The separator itself is dropped. The result always contains exactly
    ``(number of separators) + 1`` sub-lists, so leading, trailing or
    consecutive separators produce empty sub-lists, and an empty input
    yields ``[[]]``.

    Args:
        lst: Iterable of items to split.
        value: Separator; items comparing equal (``==``) to it mark the
            split points.

    Returns:
        list[list]: The segments between separators, in order.
    """
    res = []
    tmp_res = []
    for item in lst:
        if item == value:
            # Close the current segment (possibly empty) and start a new one.
            res.append(tmp_res)
            tmp_res = []
        else:
            tmp_res.append(item)
    # The trailing segment is always emitted, even when it is empty.
    res.append(tmp_res)
    return res
class DatasetInfoHook(Hook):
    """Hook that logs the size and the first decoded sample of each dataset.

    Args:
        tokenizer: Config passed to ``BUILDER.build`` to construct the
            tokenizer used for decoding ``input_ids``.
        is_intern_repo_dataset (bool): When True, the absolute value of
            every token id is taken before decoding. Defaults to False.
    """

    def __init__(self, tokenizer, is_intern_repo_dataset=False):
        self.tokenizer = BUILDER.build(tokenizer)
        self.is_intern_repo_dataset = is_intern_repo_dataset

    def log(self, runner, dataset, mode='train'):
        """Log the number of samples and the decoded first sample.

        Args:
            runner: Runner whose ``logger`` receives the messages.
            dataset: Dataset supporting ``len()`` and integer indexing;
                its first item must provide an ``'input_ids'`` entry.
            mode (str): Label used in the log messages. Defaults to
                ``'train'``.
        """
        runner.logger.info(f'Num {mode} samples {len(dataset)}')
        if len(dataset) == 0:
            # Nothing to decode; indexing ``dataset[0]`` would raise and
            # abort the run from inside a purely informational hook.
            runner.logger.warning(
                f'{mode} dataset is empty, skip logging an example.')
            return
        runner.logger.info(f'{mode} example:')
        input_ids = dataset[0]['input_ids']
        if self.is_intern_repo_dataset:
            # NOTE(review): presumably intern-repo datasets store some ids
            # negated; the sign is dropped before decoding -- confirm.
            input_ids = [abs(x) for x in input_ids]
        # Try to split list to be compatible with IMAGE token: the
        # placeholder id is cut out and each remaining segment is decoded
        # on its own.
        segments = split_list(input_ids, IMAGE_TOKEN_INDEX)
        # Re-insert the textual image token between the decoded segments.
        text = DEFAULT_IMAGE_TOKEN.join(
            self.tokenizer.decode(ids) for ids in segments)
        runner.logger.info(text)

    def before_train(self, runner) -> None:
        """Log the train dataset and, if a val loop exists, the eval one."""
        do_train = runner.train_loop is not None
        do_eval = runner.val_loop is not None
        if do_train:
            train_dataset = runner.train_dataloader.dataset
            self.log(runner, train_dataset, mode='train')
        if do_eval:
            eval_dataset = runner.val_dataloader.dataset
            self.log(runner, eval_dataset, mode='eval')

    def before_val(self, runner) -> None:
        """Log the eval dataset before validation."""
        eval_dataset = runner.val_dataloader.dataset
        self.log(runner, eval_dataset, mode='eval')

    def before_test(self, runner) -> None:
        """Log the test dataset before testing."""
        test_dataset = runner.test_dataloader.dataset
        self.log(runner, test_dataset, mode='test')