# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from functools import partial

import numpy as np
from datasets import DatasetDict
from mmengine.config import Config

from xtuner.dataset.utils import Packer, encode_fn
from xtuner.registry import BUILDER


def parse_args():
    parser = argparse.ArgumentParser(
        description='Verify the correctness of the config file for the '
        'custom dataset.')
    parser.add_argument('config', help='config file name or path.')
    args = parser.parse_args()
    return args
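

# `is_standard_format` below mirrors XTuner's standard single-turn format:
# each example carries a `conversation` list whose items are dicts with
# string `input` and `output` fields. A minimal sketch of one conforming
# JSON record (the texts are hypothetical placeholders):
#
#   {"conversation": [{"input": "Give three tips for staying healthy.",
#                      "output": "1. Eat a balanced diet. ..."}]}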
def is_standard_format(dataset):
    example = next(iter(dataset))
    if 'conversation' not in example:
        return False
    conversation = example['conversation']
    if not isinstance(conversation, list):
        return False
    for item in conversation:
        if (not isinstance(item, dict)) or ('input' not in item) or (
                'output' not in item):
            return False
        input, output = item['input'], item['output']
        if (not isinstance(input, str)) or (not isinstance(output, str)):
            return False
    return True


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    tokenizer = BUILDER.build(cfg.tokenizer)

    if cfg.get('framework', 'mmengine').lower() == 'huggingface':
        train_dataset = cfg.train_dataset
    else:
        train_dataset = cfg.train_dataloader.dataset

    dataset = train_dataset.dataset
    max_length = train_dataset.max_length
    dataset_map_fn = train_dataset.get('dataset_map_fn', None)
    template_map_fn = train_dataset.get('template_map_fn', None)
    max_dataset_length = train_dataset.get('max_dataset_length', 10)
    split = train_dataset.get('split', 'train')
    remove_unused_columns = train_dataset.get('remove_unused_columns', False)
    rename_maps = train_dataset.get('rename_maps', [])
    shuffle_before_pack = train_dataset.get('shuffle_before_pack', True)
    pack_to_max_length = train_dataset.get('pack_to_max_length', True)
    input_ids_with_output = train_dataset.get('input_ids_with_output', True)
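
    # As the check below enforces, the custom dataset must be declared via
    # `datasets.load_dataset` with the 'json' loader, roughly like this
    # (the data file name is a hypothetical placeholder):
    #
    #   from datasets import load_dataset
    #   dataset = dict(type=load_dataset, path='json',
    #                  data_files='your_json_file.json')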
    if dataset.get('path', '') != 'json':
        raise ValueError(
            'You are using custom datasets for SFT. '
            'The custom datasets should be in json format. To load your JSON '
            'file, you can use the following code snippet: \n'
            '"""\nfrom datasets import load_dataset \n'
            'dataset = dict(type=load_dataset, path=\'json\', '
            'data_files=\'your_json_file.json\')\n"""\n'
            'For more details, please refer to Step 5 in the '
            '`Using Custom Datasets` section of the documentation found at'
            ' docs/zh_cn/user_guides/single_turn_conversation.md.')
    try:
        dataset = BUILDER.build(dataset)
    except RuntimeError:
        raise RuntimeError(
            'Unable to load the custom JSON file using '
            '`datasets.load_dataset`. Your data-related config is '
            f'{train_dataset}. Please refer to the official documentation on'
            ' `load_dataset` (https://huggingface.co/docs/datasets/loading) '
            'for more details.')
    # `load_dataset('json', ...)` returns a `DatasetDict` keyed by split;
    # keep only the configured split.
    if isinstance(dataset, DatasetDict):
        dataset = dataset[split]
    if not is_standard_format(dataset) and dataset_map_fn is None:
        raise ValueError(
            'If the custom dataset is not in the XTuner-defined '
            'format, please utilize `dataset_map_fn` to map the original data'
            ' to the standard format. For more details, please refer to '
            'Step 1 and Step 5 in the `Using Custom Datasets` section of the '
            'documentation found at '
            '`docs/zh_cn/user_guides/single_turn_conversation.md`.')
    if is_standard_format(dataset) and dataset_map_fn is not None:
        raise ValueError(
            'If the custom dataset is already in the XTuner-defined format, '
            'please set `dataset_map_fn` to None. '
            'For more details, please refer to Step 1 and Step 5 in the '
            '`Using Custom Datasets` section of the documentation found at'
            ' docs/zh_cn/user_guides/single_turn_conversation.md.')
    max_dataset_length = min(max_dataset_length, len(dataset))
    indices = np.random.choice(
        len(dataset), max_dataset_length, replace=False)
    dataset = dataset.select(indices)
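
    # At this point the dataset is cut down to at most `max_dataset_length`
    # randomly chosen examples (10 by default), so the checks below stay
    # fast regardless of dataset size.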
    if dataset_map_fn is not None:
        dataset = dataset.map(dataset_map_fn)
        print('#' * 20 + ' dataset after `dataset_map_fn` ' + '#' * 20)
        print(dataset[0]['conversation'])

    if template_map_fn is not None:
        template_map_fn = BUILDER.build(template_map_fn)
        dataset = dataset.map(template_map_fn)
        print('#' * 20 + ' dataset after adding templates ' + '#' * 20)
        print(dataset[0]['conversation'])
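
    # The template step typically wraps each `input` with the target model's
    # chat template (system/user/assistant markers), so the printout above is
    # the easiest place to spot a mismatched or missing template.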
    for old, new in rename_maps:
        dataset = dataset.rename_column(old, new)

    if pack_to_max_length and (not remove_unused_columns):
        raise ValueError('We have to remove unused columns if '
                         '`pack_to_max_length` is set to True.')
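
    # Packing concatenates tokenized samples across rows, so leftover raw
    # columns would no longer align with the packed `input_ids`; presumably
    # that is why unused columns must be removed whenever
    # `pack_to_max_length` is enabled.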
    dataset = dataset.map(
        partial(
            encode_fn,
            tokenizer=tokenizer,
            max_length=max_length,
            input_ids_with_output=input_ids_with_output),
        remove_columns=list(dataset.column_names)
        if remove_unused_columns else None)

    print('#' * 20 + ' encoded input_ids ' + '#' * 20)
    print(dataset[0]['input_ids'])
    print('#' * 20 + ' encoded labels ' + '#' * 20)
    print(dataset[0]['labels'])
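
    # In the printed labels, prompt tokens are expected to be masked with
    # -100 (the common loss-ignore index) so that only the output span is
    # trained on; if every label is -100, the map functions or template are
    # likely misconfigured.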
    if pack_to_max_length and split == 'train':
        if shuffle_before_pack:
            dataset = dataset.shuffle()
            dataset = dataset.flatten_indices()
        dataset = dataset.map(Packer(max_length), batched=True)
        print('#' * 20 + ' input_ids after packed to max_length ' +
              '#' * 20)
        print(dataset[0]['input_ids'])
        print('#' * 20 + ' labels after packed to max_length ' + '#' * 20)
        print(dataset[0]['labels'])


if __name__ == '__main__':
    main()
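
# Example invocation (script and config paths are hypothetical placeholders):
#   python check_custom_dataset.py configs/my_custom_sft_config.py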