Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook | |
from xtuner.registry import BUILDER | |
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX | |
def split_list(lst, value):
    """Split ``lst`` into sub-lists at every occurrence of ``value``.

    The separator itself is dropped, in the same way ``str.split`` drops
    its delimiter. The result always holds exactly one more segment than
    the number of separators found, so leading, trailing and consecutive
    separators produce empty segments, and an empty input yields ``[[]]``.

    Args:
        lst: Iterable of items to partition.
        value: Separator; every item comparing equal to it starts a new
            segment.

    Returns:
        list[list]: The segments, in their original order.
    """
    segments = []
    current = []
    for item in lst:
        if item == value:
            # Close the running segment and start a fresh one.
            segments.append(current)
            current = []
        else:
            current.append(item)
    # The tail after the last separator (possibly empty) is always emitted.
    segments.append(current)
    return segments
class DatasetInfoHook(Hook):
    """Hook that logs the size and a decoded first sample of each dataset.

    Args:
        tokenizer: Config handed to ``BUILDER.build`` to construct the
            tokenizer used for decoding the example.
        is_intern_repo_dataset (bool): When True, the absolute value of
            every token id is taken before decoding. Defaults to False.
            NOTE(review): presumably such datasets store some ids negated;
            confirm against the dataset implementation.
    """

    def __init__(self, tokenizer, is_intern_repo_dataset=False):
        self.tokenizer = BUILDER.build(tokenizer)
        self.is_intern_repo_dataset = is_intern_repo_dataset

    def log(self, runner, dataset, mode='train'):
        """Log the number of samples and the decoded first sample.

        Args:
            runner: Object exposing ``logger.info``.
            dataset: Sized, indexable dataset whose items provide an
                ``'input_ids'`` entry.
            mode (str): Label used in the log messages.
        """
        runner.logger.info(f'Num {mode} samples {len(dataset)}')
        runner.logger.info(f'{mode} example:')
        input_ids = dataset[0]['input_ids']
        if self.is_intern_repo_dataset:
            input_ids = [abs(x) for x in input_ids]
        # Try to split list to be compatible with IMAGE token:
        # IMAGE_TOKEN_INDEX is cut out of the id sequence, each remaining
        # segment is decoded on its own, and DEFAULT_IMAGE_TOKEN is put
        # back between the decoded segments.
        segments = split_list(input_ids, IMAGE_TOKEN_INDEX)
        text = DEFAULT_IMAGE_TOKEN.join(
            self.tokenizer.decode(ids) for ids in segments)
        runner.logger.info(text)

    def before_train(self, runner) -> None:
        """Log the train dataset and/or eval dataset before training.

        Each dataset is logged only when the corresponding loop
        (``runner.train_loop`` / ``runner.val_loop``) is not None.
        """
        do_train = runner.train_loop is not None
        do_eval = runner.val_loop is not None
        if do_train:
            train_dataset = runner.train_dataloader.dataset
            self.log(runner, train_dataset, mode='train')
        if do_eval:
            eval_dataset = runner.val_dataloader.dataset
            self.log(runner, eval_dataset, mode='eval')

    def before_val(self, runner) -> None:
        """Log the eval dataset before validation."""
        eval_dataset = runner.val_dataloader.dataset
        self.log(runner, eval_dataset, mode='eval')

    def before_test(self, runner) -> None:
        """Log the test dataset before testing."""
        test_dataset = runner.test_dataloader.dataset
        self.log(runner, test_dataset, mode='test')