from typing import List, Tuple

from torchtune.data import StackExchangedPairedTemplate
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.tokenizers import Tokenizer


def extract_assistant_content(sample: dict) -> dict:
    """
    Extracts the text content of the final assistant response from the
    chosen and rejected message lists.

    Args:
        sample (dict): A dictionary containing the prompt, chosen, and
            rejected lists of messages.

    Returns:
        dict: The sample dictionary with ``chosen`` and ``rejected`` replaced
            by the text content of their final assistant messages.
    """
    sample["chosen"] = sample["chosen"][-1]["content"]
    sample["rejected"] = sample["rejected"][-1]["content"]
    return sample


class ModifiedPreferenceDataset(PreferenceDataset):
    """A ``PreferenceDataset`` variant that returns a tuple of token id lists
    instead of a batch dictionary."""

    def __getitem__(
        self, index: int
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        sample = self._data[index]
        batch = self._prepare_sample(sample)
        # Unpack the tokenized sample into
        # (chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels).
        return (
            batch["chosen_input_ids"],
            batch["chosen_labels"],
            batch["rejected_input_ids"],
            batch["rejected_labels"],
        )


def orpo_dpo_mix_40k_dataset(
    tokenizer: Tokenizer,
    *,
    max_seq_len: int = 8192,
    data_dir: str = "data",
) -> ModifiedPreferenceDataset:
    """
    Preference dataset for the 'mlabonne/orpo-dpo-mix-40k' dataset.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data.
        max_seq_len (int): Maximum number of tokens in the returned input and
            label token id lists. Default is 8192.
        data_dir (str): Directory in which to store the downloaded dataset.
            Default is "data".

    Returns:
        ModifiedPreferenceDataset: The modified preference dataset built from
            the 'mlabonne/orpo-dpo-mix-40k' dataset.
    """
    return ModifiedPreferenceDataset(
        tokenizer=tokenizer,
        source="mlabonne/orpo-dpo-mix-40k",
        template=StackExchangedPairedTemplate(),
        transform=extract_assistant_content,
        column_map={
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        max_seq_len=max_seq_len,
        split="train",
        data_dir=data_dir,
    )
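

# Usage sketch: how the builder above might be called. This assumes a
# torchtune Llama2 tokenizer and a hypothetical checkpoint path
# ("/path/to/tokenizer.model"); swap in the tokenizer and path that match
# your setup.
#
#     from torchtune.models.llama2 import llama2_tokenizer
#
#     tokenizer = llama2_tokenizer("/path/to/tokenizer.model")
#     ds = orpo_dpo_mix_40k_dataset(tokenizer, max_seq_len=4096)
#
#     # Each item is a 4-tuple of token id lists, ready for DPO-style loss
#     # computation over the chosen/rejected pair.
#     chosen_ids, chosen_labels, rejected_ids, rejected_labels = ds[0]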