File size: 2,189 Bytes
f07a9b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from datasets import load_dataset
from torchtune.data import StackExchangedPairedTemplate
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.tokenizers import Tokenizer
from typing import Optional, Tuple, List

def extract_assistant_content(sample):
    """
    Replace the 'chosen' and 'rejected' message lists with the text content
    of their final (assistant) message.

    Args:
        sample (dict): A dictionary containing the prompt, chosen, and rejected lists of messages.

    Returns:
        dict: The same sample dictionary, mutated in place, with 'chosen' and
            'rejected' collapsed to the last message's 'content' string.
    """
    # Both preference columns get the same treatment: keep only the final
    # message's text, which is the assistant response in this dataset layout.
    for column in ("chosen", "rejected"):
        sample[column] = sample[column][-1]["content"]
    return sample

class ModifiedPreferenceDataset(PreferenceDataset):
    """PreferenceDataset variant whose items are plain 4-tuples of token-id lists
    instead of a dict, in (chosen_input_ids, chosen_labels, rejected_input_ids,
    rejected_labels) order."""

    # Fixed extraction order — callers rely on this positional layout.
    _OUTPUT_KEYS = (
        "chosen_input_ids",
        "chosen_labels",
        "rejected_input_ids",
        "rejected_labels",
    )

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int], List[int]]:
        prepared = self._prepare_sample(self._data[index])
        return tuple(prepared[key] for key in self._OUTPUT_KEYS)

def orpo_dpo_mix_40k_dataset(
    tokenizer: Tokenizer,
    *,
    max_seq_len: int = 8192,
    data_dir: str = "data",
) -> ModifiedPreferenceDataset:
    """
    Preference dataset for the 'mlabonne/orpo-dpo-mix-40k' dataset.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
            Default is 8192.
        data_dir (str): Directory to store the downloaded dataset. Default is "data".

    Returns:
        ModifiedPreferenceDataset: The modified preference dataset built from the 'mlabonne/orpo-dpo-mix-40k' dataset.
    """
    # The docstring previously advertised a `data_dir` argument that the
    # signature didn't accept; it is now a real keyword-only parameter with
    # the same default ("data") that was hard-coded before.
    return ModifiedPreferenceDataset(
        tokenizer=tokenizer,
        source="mlabonne/orpo-dpo-mix-40k",
        template=StackExchangedPairedTemplate(),
        transform=extract_assistant_content,
        # Column names in the source dataset already match the template's
        # expected names, so this map is an identity mapping.
        column_map={
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        max_seq_len=max_seq_len,
        split="train",
        data_dir=data_dir,
    )