|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
from dataclasses import dataclass |
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional |
|
|
|
from ..extras.constants import DATA_CONFIG |
|
from ..extras.misc import use_modelscope |
|
|
|
|
|
if TYPE_CHECKING: |
|
from ..hparams import DataArguments |
|
|
|
|
|
@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.

    Describes where a dataset is loaded from, which record layout it uses,
    and how its column/tag names map onto the loader's expected fields.
    """

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt", "molqa"] = "molqa"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    # rlhf columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca-style columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt-style columns
    messages: Optional[str] = "conversations"
    # sharegpt role/content tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"
    # molqa-specific columns
    property: Optional[str] = "property"
    retro: Optional[str] = "retro"

    def __repr__(self) -> str:
        # Display just the dataset name when the attr object is printed/logged.
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        """Set attribute ``key`` from ``obj[key]``, falling back to ``default``."""
        value = obj.get(key, default)
        setattr(self, key, value)
|
|
|
def get_dataset_attr(data_args: "DataArguments") -> "DatasetAttr":
    r"""
    Build the ``DatasetAttr`` for ``data_args.dataset`` from the dataset
    config file (``DATA_CONFIG``) under ``data_args.dataset_dir``.

    Args:
        data_args: data arguments; only ``dataset`` and ``dataset_dir`` are read.

    Returns:
        A single ``DatasetAttr`` describing the configured dataset.

    Raises:
        ValueError: if no dataset name is given, the config file cannot be
            read, the dataset is undefined in the config, or a ``columns``
            mapping is given for a non-"molqa" formatting.
    """
    if data_args.dataset is not None:
        dataset_name = data_args.dataset.strip()
    else:
        raise ValueError("Please specify the dataset name.")

    config_path = os.path.join(data_args.dataset_dir, DATA_CONFIG)
    try:
        with open(config_path, "r") as f:
            dataset_info = json.load(f)
    except Exception as err:
        # Chain the original exception so the root cause stays visible.
        raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) from err

    if dataset_name not in dataset_info:
        raise ValueError("Undefined dataset {} in {}.".format(dataset_name, DATA_CONFIG))

    dataset_config = dataset_info[dataset_name]
    dataset_attr = DatasetAttr("file", dataset_name=dataset_config["file_name"])

    dataset_attr.set_attr("formatting", dataset_config, default="molqa")
    dataset_attr.set_attr("ranking", dataset_config, default=False)
    dataset_attr.set_attr("subset", dataset_config)
    dataset_attr.set_attr("folder", dataset_config)
    dataset_attr.set_attr("num_samples", dataset_config)

    if "columns" in dataset_config:
        # Only the "molqa" formatting supports a custom column mapping here;
        # raise ValueError (not assert, which is stripped under -O) for
        # consistency with the other validation failures above.
        if dataset_attr.formatting != "molqa":
            raise ValueError("Column mapping is only supported for the molqa formatting.")

        column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
        column_names.extend(["prompt", "query", "response", "history", "property", "retro"])

        for column_name in column_names:
            dataset_attr.set_attr(column_name, dataset_config["columns"])

    return dataset_attr