# NOTE(review): removed extraction artifacts ("Spaces:" / "Runtime error" banner lines)
# that were not part of the Python source.
import os
from typing import TYPE_CHECKING, List, Union

from datasets import concatenate_datasets, interleave_datasets, load_dataset

from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger

# Imports used only for type annotations; guarded to avoid import cycles at runtime.
if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from llmtuner.hparams import ModelArguments, DataArguments

# Module-level logger, keyed by this module's import path.
logger = get_logger(__name__)
def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Load, align and merge every dataset listed in ``data_args.dataset_list``.

    Each dataset may come from the HF Hub, a local loading script, or local
    file(s). Columns are renamed to the canonical ``prompt`` / ``query`` /
    ``response`` / ``history`` names and an optional ``system`` column is
    attached. Multiple datasets are merged according to
    ``data_args.mix_strategy`` ("concat" or "interleave_*").

    Args:
        model_args: supplies ``cache_dir`` and ``use_auth_token``.
        data_args: supplies the dataset list, split, streaming flag,
            ``max_samples`` cap and mixing options.

    Returns:
        A ``Dataset`` (or ``IterableDataset`` when streaming) ready for
        preprocessing.

    Raises:
        ValueError: unknown file, unsupported file extension, mismatched
            extensions inside one directory, or unknown mixing strategy.
        NotImplementedError: unrecognized ``load_from`` source.
    """
    max_samples = data_args.max_samples
    all_datasets: List[Union["Dataset", "IterableDataset"]] = []  # support multiple datasets

    for dataset_attr in data_args.dataset_list:
        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            if os.path.isdir(local_path):  # directory: every file must share one extension
                for file_name in os.listdir(local_path):
                    data_files.append(os.path.join(local_path, file_name))
                    file_type = EXT2TYPE.get(file_name.split(".")[-1], None)
                    if data_path is None:
                        data_path = file_type
                    elif data_path != file_type:
                        # was an `assert`; explicit raise survives `python -O`
                        raise ValueError("file type does not match.")
            elif os.path.isfile(local_path):  # single file
                data_files.append(local_path)
                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            if data_path is None:  # was an `assert`; explicit raise survives `python -O`
                raise ValueError("File extension must be txt, csv, json or jsonl.")
            checksum(data_files, dataset_attr.dataset_sha1)
        else:
            raise NotImplementedError

        dataset = load_dataset(
            data_path,
            data_files=data_files,
            split=data_args.split,
            cache_dir=model_args.cache_dir,
            streaming=data_args.streaming,
            use_auth_token=True if model_args.use_auth_token else None
        )

        if max_samples is not None:
            if data_args.streaming:
                # IterableDataset has neither __len__ nor select();
                # take() lazily caps the number of yielded samples.
                dataset = dataset.take(max_samples)
            else:
                dataset = dataset.select(range(min(len(dataset), max_samples)))

        for column_name in ["prompt", "query", "response", "history"]:  # align datasets
            if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

        if dataset_attr.system_prompt:  # add system prompt
            if data_args.streaming:
                # add_column is unavailable on IterableDataset; map() instead
                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
            else:
                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))

        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
    else:
        raise ValueError("Unknown mixing strategy.")