import logging
import os
from datetime import timedelta
from functools import partial

import numpy as np
from datasets import DatasetDict, concatenate_datasets
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.utils.misc import get_object_from_string
from torch import distributed as dist

from xtuner.registry import BUILDER, MAP_FUNC

from .utils import Packer, encode_fn


def get_lengths(example):
    """Record the token length of a tokenized example."""
    return {'length': len(example['input_ids'])}


def build_origin_dataset(dataset, split):
    """Build the raw dataset and select the requested split.

    If ``split`` is None and the result is a ``DatasetDict``, all splits are
    concatenated into a single dataset.
    """
    if isinstance(dataset, DatasetDict):
        if split is None:
            dataset = concatenate_datasets(dataset.values())
        else:
            dataset = dataset[split]
    elif isinstance(dataset, (dict, Config, ConfigDict)):
        dataset = BUILDER.build(dataset)
        if isinstance(dataset, DatasetDict):
            if split is None:
                dataset = concatenate_datasets(dataset.values())
            else:
                dataset = dataset[split]
    return dataset
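

# A minimal usage sketch of `build_origin_dataset` (hedged: the hub dataset
# name below is a placeholder, not something this module ships with):
#
#   from datasets import load_dataset
#   ds = build_origin_dataset(load_dataset('org/some-dataset'), split='train')
#   # With split=None, all available splits would be concatenated instead.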


def map_dataset(dataset, dataset_map_fn, map_num_proc):
    """Map the dataset to xTuner's expected format with ``dataset_map_fn``,
    resolving string identifiers through MAP_FUNC or an import path."""
    if isinstance(dataset_map_fn, str):
        map_fn_obj = MAP_FUNC.get(dataset_map_fn) or get_object_from_string(
            dataset_map_fn)
        if map_fn_obj is not None:
            dataset_map_fn = map_fn_obj
        else:
            raise TypeError('dataset_map_fn must be a function or a '
                            "registered function's string in MAP_FUNC, "
                            f"but got a string of '{dataset_map_fn}'")

    dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc)
    return dataset
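

# A hedged sketch of a custom `dataset_map_fn`. The field names below are
# illustrative; the only contract is that the function returns records in
# xTuner's `conversation` format:
#
#   def my_map_fn(example):
#       return {'conversation': [{'input': example['question'],
#                                 'output': example['answer']}]}
#
#   ds = map_dataset(ds, my_map_fn, map_num_proc=8)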


def add_template_to_dataset(dataset, template_map_fn, map_num_proc):
    """Apply the prompt template to every conversation, then drop samples
    whose conversation ends up empty."""
    if isinstance(template_map_fn, (dict, Config, ConfigDict)):
        template_map_fn = BUILDER.build(template_map_fn)
    dataset = dataset.map(template_map_fn, num_proc=map_num_proc)

    dataset = dataset.filter(
        lambda example: len(example['conversation']) > 0,
        num_proc=map_num_proc)
    return dataset
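

# A hedged sketch of a `template_map_fn` config. `template_map_fn_factory`
# and `PROMPT_TEMPLATE` exist in xTuner, but treat the exact template name as
# an assumption and substitute the one matching your base model:
#
#   from xtuner.dataset.map_fns import template_map_fn_factory
#   from xtuner.utils import PROMPT_TEMPLATE
#
#   template_map_fn = dict(
#       type=template_map_fn_factory, template=PROMPT_TEMPLATE.internlm2_chat)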


def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
                     input_ids_with_output, remove_unused_columns,
                     map_num_proc):
    """Tokenize every sample with ``encode_fn``, building the tokenizer from
    its config first if necessary."""
    assert (tokenizer is not None) and (max_length is not None), \
        'Both `tokenizer` and `max_length` are required for tokenization, ' \
        f'but got ({tokenizer}, {max_length}).'
    if isinstance(tokenizer, (dict, Config, ConfigDict)):
        tokenizer = BUILDER.build(tokenizer)
    dataset = dataset.map(
        partial(
            encode_fn,
            tokenizer=tokenizer,
            max_length=max_length,
            with_image_token=with_image_token,
            input_ids_with_output=input_ids_with_output),
        remove_columns=list(dataset.column_names)
        if remove_unused_columns else None,
        num_proc=map_num_proc)
    return dataset


def pack_dataset(dataset, max_length, use_varlen_attn, shuffle_before_pack,
                 map_num_proc):
    """Pack tokenized samples into fixed-length chunks of ``max_length``."""
    if shuffle_before_pack:
        dataset = dataset.shuffle()
        dataset = dataset.flatten_indices(num_proc=map_num_proc)
    dataset = dataset.map(
        Packer(max_length, use_varlen_attn=use_varlen_attn),
        batched=True,
        num_proc=map_num_proc)
    return dataset
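

# A hedged sketch of what packing does: after tokenization, consecutive
# samples are concatenated and re-chunked so that every resulting sample
# holds exactly `max_length` tokens, e.g. conceptually:
#
#   before: lengths [300, 900, 1200, ...]
#   after : lengths [2048, 2048, ...]   # assuming max_length=2048
#
# Shuffling before packing randomizes which samples share a chunk;
# `flatten_indices` materializes the shuffled order so the multiprocessing
# map operates on contiguous data.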


def process(dataset,
            do_dataset_tokenization=True,
            tokenizer=None,
            max_length=None,
            dataset_map_fn=None,
            template_map_fn=None,
            max_dataset_length=None,
            split='train',
            remove_unused_columns=False,
            rename_maps=[],
            shuffle_before_pack=True,
            pack_to_max_length=True,
            use_varlen_attn=False,
            input_ids_with_output=True,
            with_image_token=False,
            map_num_proc=32):
    """Post-process the dataset loaded from the Hugging Face Hub, or a local
    dataset.

    Args:
        dataset: The dataset to be post-processed.
        do_dataset_tokenization: Whether the dataset needs to be tokenized
            in this function. Defaults to True.
        tokenizer: The tokenizer that processes raw text and outputs an
            Encoding. If `do_dataset_tokenization` is True, this argument
            must not be None. Defaults to None.
        max_length: Max length of the sequence. If `do_dataset_tokenization`
            or `pack_to_max_length` is True, this argument must not be None.
            Defaults to None.
        dataset_map_fn: Map the original dataset format to the one defined
            by xTuner.
        template_map_fn: Add the prompt template to the dataset.
        max_dataset_length: If the dataset is too long, randomly sample
            `max_dataset_length` examples from it.
        split: Which split of the data to load.
            If `None`, a single concatenated dataset with all splits
            (typically `datasets.Split.TRAIN` and `datasets.Split.TEST`)
            is returned.
            If given, a single Dataset is returned.
        remove_unused_columns: Whether to remove columns from the dataset
            that are not used during training.
        rename_maps: Pairs of (old_name, new_name) used to rename dataset
            columns.
        shuffle_before_pack: Whether to shuffle the dataset before
            packing it.
        pack_to_max_length: Whether to pack the dataset to `max_length`.
            This usually improves GPU utilization and therefore reduces
            training time.
        use_varlen_attn: If True, attention is calculated according to the
            actual length of each original sequence rather than the length
            of the whole packed sequence.
        input_ids_with_output: Whether to put the ground-truth output
            corresponding to the question into the dataset. Typically set
            it to True during training and False during testing.
        with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
            IMAGE_TOKEN_INDEX. Typically set it to True during the training
            of VLMs.
        map_num_proc: Max number of processes when mapping the dataset.
    """
    if use_varlen_attn:
        assert pack_to_max_length, \
            '`pack_to_max_length` in `process_hf_dataset` should be set to ' \
            'True if `use_varlen_attn` is True.'
    if pack_to_max_length:
        assert split == 'train' or split is None, \
            ('`split` should be `train` or `None` if `pack_to_max_length` is '
             f'True, but got {split}.')

    dataset = build_origin_dataset(dataset, split)

    # Optionally subsample the dataset down to at most `max_dataset_length`
    # randomly chosen examples.
    if max_dataset_length is not None:
        max_dataset_length = min(max_dataset_length, len(dataset))
        indices = np.random.choice(
            len(dataset), max_dataset_length, replace=False)
        dataset = dataset.select(indices)

    if dataset_map_fn is not None:
        dataset = map_dataset(dataset, dataset_map_fn, map_num_proc)

    if template_map_fn is not None:
        dataset = add_template_to_dataset(dataset, template_map_fn,
                                          map_num_proc)

    for old, new in rename_maps:
        dataset = dataset.rename_column(old, new)

    # Packing requires homogeneous columns, so unused columns must go.
    if pack_to_max_length and (not remove_unused_columns):
        print_log(
            'We have to remove unused columns if '
            '`pack_to_max_length` is set to True.',
            logger='current',
            level=logging.WARNING)
        remove_unused_columns = True

    if do_dataset_tokenization:
        dataset = tokenize_dataset(dataset, tokenizer, max_length,
                                   with_image_token, input_ids_with_output,
                                   remove_unused_columns, map_num_proc)
    else:
        assert {'input_ids', 'labels'}.issubset(dataset.column_names)

    if input_ids_with_output:
        # Drop samples whose labels are all ignored (no trainable tokens).
        dataset = dataset.filter(
            lambda example: any(label >= 0 for label in example['labels']),
            num_proc=map_num_proc)

    if pack_to_max_length:
        dataset = pack_dataset(dataset, max_length, use_varlen_attn,
                               shuffle_before_pack, map_num_proc)

    # Record per-sample lengths, e.g. for length-grouped samplers.
    dataset = dataset.map(get_lengths, num_proc=map_num_proc)
    setattr(dataset, 'length', dataset['length'])

    return dataset


def process_hf_dataset(dataset,
                       do_dataset_tokenization=True,
                       tokenizer=None,
                       max_length=None,
                       dataset_map_fn=None,
                       template_map_fn=None,
                       max_dataset_length=None,
                       split='train',
                       remove_unused_columns=False,
                       rename_maps=[],
                       shuffle_before_pack=True,
                       pack_to_max_length=True,
                       use_varlen_attn=False,
                       input_ids_with_output=True,
                       with_image_token=False,
                       map_num_proc=4):
    """Post-process the dataset loaded from the Hugging Face Hub, or a local
    dataset, synchronizing the result across ranks under distributed
    training.

    Args:
        dataset: The dataset to be post-processed.
        do_dataset_tokenization: Whether the dataset needs to be tokenized
            in this function. Defaults to True.
        tokenizer: The tokenizer that processes raw text and outputs an
            Encoding. If `do_dataset_tokenization` is True, this argument
            must not be None. Defaults to None.
        max_length: Max length of the sequence. If `do_dataset_tokenization`
            or `pack_to_max_length` is True, this argument must not be None.
            Defaults to None.
        dataset_map_fn: Map the original dataset format to the one defined
            by xTuner.
        template_map_fn: Add the prompt template to the dataset.
        max_dataset_length: If the dataset is too long, randomly sample
            `max_dataset_length` examples from it.
        split: Which split of the data to load.
            If `None`, a single concatenated dataset with all splits
            (typically `datasets.Split.TRAIN` and `datasets.Split.TEST`)
            is returned.
            If given, a single Dataset is returned.
        remove_unused_columns: Whether to remove columns from the dataset
            that are not used during training.
        rename_maps: Pairs of (old_name, new_name) used to rename dataset
            columns.
        shuffle_before_pack: Whether to shuffle the dataset before
            packing it.
        pack_to_max_length: Whether to pack the dataset to `max_length`.
            This usually improves GPU utilization and therefore reduces
            training time.
        use_varlen_attn: If True, attention is calculated according to the
            actual length of each original sequence rather than the length
            of the whole packed sequence.
        input_ids_with_output: Whether to put the ground-truth output
            corresponding to the question into the dataset. Typically set
            it to True during training and False during testing.
        with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
            IMAGE_TOKEN_INDEX. Typically set it to True during the training
            of VLMs.
        map_num_proc: Max number of processes when mapping the dataset.
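
    Example:
        A minimal sketch (hedged): the model and dataset names below are
        placeholders, and `dataset_map_fn` must convert raw records into
        xTuner's `conversation` format (e.g. a function registered in
        MAP_FUNC or an import-path string)::

            from datasets import load_dataset
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained('org/base-model')
            train_ds = process_hf_dataset(
                dataset=load_dataset('org/some-dataset'),
                tokenizer=tokenizer,
                max_length=2048,
                dataset_map_fn='your_project.map_fns.your_map_fn',
                split='train')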
    """
    kwargs = dict(
        dataset=dataset,
        do_dataset_tokenization=do_dataset_tokenization,
        tokenizer=tokenizer,
        max_length=max_length,
        dataset_map_fn=dataset_map_fn,
        template_map_fn=template_map_fn,
        max_dataset_length=max_dataset_length,
        split=split,
        remove_unused_columns=remove_unused_columns,
        rename_maps=rename_maps,
        shuffle_before_pack=shuffle_before_pack,
        pack_to_max_length=pack_to_max_length,
        use_varlen_attn=use_varlen_attn,
        input_ids_with_output=input_ids_with_output,
        with_image_token=with_image_token,
        map_num_proc=map_num_proc)
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process case: preprocess directly.
        return process(**kwargs)
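
    # Hedged example: the preprocessing window on rank 0 can be widened via
    # the XTUNER_DATASET_TIMEOUT environment variable (in minutes), e.g. for
    # a hypothetical 8-GPU launch:
    #
    #   XTUNER_DATASET_TIMEOUT=60 torchrun --nproc_per_node=8 train.py ...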
    xtuner_dataset_timeout = timedelta(
        minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=30)))
    print_log(
        f'xtuner_dataset_timeout = {xtuner_dataset_timeout}', logger='current')

    # Use a gloo group, since `monitored_barrier` only supports the gloo
    # backend.
    group_gloo = dist.new_group(backend='gloo', timeout=xtuner_dataset_timeout)

    # Only rank 0 preprocesses; the other ranks wait at the barrier and then
    # receive the processed dataset via broadcast.
    if dist.get_rank() == 0:
        dataset = process(**kwargs)
        objects = [dataset]
    else:
        objects = [None]

    dist.monitored_barrier(group=group_gloo, timeout=xtuner_dataset_timeout)
    dist.broadcast_object_list(objects, src=0)
    return objects[0]