File size: 2,189 Bytes
f07a9b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from datasets import load_dataset
from torchtune.data import StackExchangedPairedTemplate
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.tokenizers import Tokenizer
from typing import Optional, Tuple, List
def extract_assistant_content(sample):
    """
    Replace the ``chosen`` / ``rejected`` conversations with the text of their final message.

    Each of ``chosen`` and ``rejected`` is expected to be a list of message
    dicts; the final message is assumed to be the assistant response being
    compared, and only its ``content`` is kept. All other keys are copied
    through unchanged.

    Args:
        sample (dict): A mapping with at least ``chosen`` and ``rejected`` keys.
            Each value is either a list of message dicts carrying a
            ``"content"`` key, or an already-extracted string.

    Returns:
        dict: A shallow copy of ``sample`` in which ``chosen`` and ``rejected``
        hold the final message's content. The input mapping is not modified.

    Raises:
        KeyError: If ``chosen``/``rejected`` is missing, or the final message
            has no ``"content"`` key.
        IndexError: If a message list is empty.
    """
    # Work on a copy so the caller's sample is never mutated behind its back.
    result = dict(sample)
    for key in ("chosen", "rejected"):
        value = result[key]
        # Already reduced to text (e.g. the transform ran twice): keep as is
        # instead of failing on str indexing.
        if isinstance(value, str):
            continue
        result[key] = value[-1]["content"]
    return result
class ModifiedPreferenceDataset(PreferenceDataset):
    """
    ``PreferenceDataset`` variant whose items are positional tuples.

    The parent's ``_prepare_sample`` result is indexed by key and unpacked into
    a fixed-order 4-tuple, so an item can be consumed positionally:
    ``chosen_ids, chosen_labels, rejected_ids, rejected_labels = ds[i]``.
    """

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int], List[int]]:
        """
        Return the tokenized chosen/rejected pair stored at ``index``.

        Returns:
            ``(chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels)``
        """
        prepared = self._prepare_sample(self._data[index])
        chosen_pair = (prepared["chosen_input_ids"], prepared["chosen_labels"])
        rejected_pair = (prepared["rejected_input_ids"], prepared["rejected_labels"])
        return chosen_pair + rejected_pair
def orpo_dpo_mix_40k_dataset(
    tokenizer: Tokenizer,
    *,
    max_seq_len: int = 8192,
    data_dir: str = "data",
    split: str = "train",
) -> ModifiedPreferenceDataset:
    """
    Build a preference dataset from the ``mlabonne/orpo-dpo-mix-40k`` dataset.

    Each sample is passed through ``extract_assistant_content`` so that
    ``chosen`` / ``rejected`` become the text of their final message. A
    ``StackExchangedPairedTemplate`` is used as the prompt template.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data.
        max_seq_len (int): Maximum number of tokens in the returned input and
            label token id lists. Default is 8192.
        data_dir (str): Forwarded to ``PreferenceDataset`` as a dataset-loading
            option. NOTE(review): assuming it is passed on to
            ``datasets.load_dataset`` (confirm against PreferenceDataset), it
            selects the sub-directory *inside the dataset repository* whose
            files are loaded. It is not a local download/cache location; that
            is ``cache_dir``. Default is "data".
        split (str): Dataset split to load. Default is "train".

    Returns:
        ModifiedPreferenceDataset: Dataset whose items are
        ``(chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels)``
        tuples.
    """
    return ModifiedPreferenceDataset(
        tokenizer=tokenizer,
        source="mlabonne/orpo-dpo-mix-40k",
        template=StackExchangedPairedTemplate(),
        transform=extract_assistant_content,
        # Identity mapping, kept explicit to document the expected column names.
        column_map={
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        max_seq_len=max_seq_len,
        split=split,
        data_dir=data_dir,
    )
|