# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
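"""Convert a HuggingFace finetuning dataset to MDS format for streaming.

Streams the dataset via `datasets.load_dataset(..., streaming=True)`, applies a
preprocessing function that reformats each sample into 'prompt'/'response'
pairs, and writes the result as MDS shards with `streaming.MDSWriter`.
"""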

import os
import platform
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterable, List, Optional, Union

import datasets as hf_datasets
import psutil
from streaming import MDSWriter
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from llmfoundry.data.finetuning.tasks import dataset_constructor


def parse_args() -> Namespace:
    """Parse commandline arguments."""
    parser = ArgumentParser(description='Convert dataset into MDS format.')
    parser.add_argument(
        '--dataset',
        type=str,
        required=True,
        help=
        'Name/path of the dataset (e.g., first argument to `datasets.load_dataset`)'
    )
    parser.add_argument('--data_subset',
                        type=str,
                        default=None,
                        help='(Optional) subset of data to use.')
    parser.add_argument('--splits',
                        nargs='+',
                        default=['train', 'validation'],
                        help='Which splits of the dataset to convert.')
    parser.add_argument('--preprocessor',
                        type=str,
                        default=None,
                        help='Name or import path of function used to preprocess (reformat) the dataset. ' +\
                             'See README for additional details.')
    parser.add_argument(
        '--skip-preprocessing',
        action='store_true',
        help=
        'Whether to skip preprocessing (e.g., if the dataset is already formatted correctly)'
    )
    parser.add_argument(
        '--out_root',
        type=str,
        required=True,
        help=
        'Root path of output directory where MDS shards will be stored. Can be a remote URI.'
    )
    parser.add_argument(
        '--local',
        type=str,
        default=None,
        help=
        '(Optional) root path of local directory if you want to keep a local copy when out_root is remote.'
    )
    parser.add_argument('--compression',
                        type=str,
                        default=None,
                        help='(Optional) name of compression algorithm to use.')
    parser.add_argument('--num_workers', type=int, required=False, default=None)

    parsed = parser.parse_args()

    if os.path.isdir(parsed.out_root) and len(
            set(os.listdir(parsed.out_root)).intersection(set(
                parsed.splits))) > 0:
        raise ValueError(
            f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.'
        )

    return parsed


class SimpleDataset(IterableDataset):
    """An IterableDataset that returns text samples for MDSWriter.

    Returns dicts of {'key': bytes} for each 'key' in `columns`
    """

    def __init__(self, dataset_name: str, data_subset: Union[str, None],
                 split: str, columns: List[str]):
        self.hf_dataset = hf_datasets.load_dataset(path=dataset_name,
                                                   name=data_subset,
                                                   split=split,
                                                   streaming=True)
        self.columns = columns

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        for sample in self.hf_dataset:
            # convert to bytes to store in MDS binary format
            yield {key: sample[key].encode('utf-8') for key in self.columns}
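
# Example usage (sketch): iterating a SimpleDataset yields MDS-ready bytes.
# 'org/my-dataset' is a hypothetical dataset whose samples already contain
# 'prompt' and 'response' string fields.
#
#   dataset = SimpleDataset(dataset_name='org/my-dataset', data_subset=None,
#                           split='train', columns=['prompt', 'response'])
#   for sample in dataset:
#       ...  # {'prompt': b'...', 'response': b'...'}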


def build_dataloader(dataset: Union[SimpleDataset, hf_datasets.IterableDataset],
                     batch_size: int,
                     num_workers: Optional[int] = None) -> DataLoader:
    if num_workers is None:
        # Multiple workers are only supported on Linux machines
        if 'linux' in platform.platform().lower():
            num_workers = max(1, psutil.cpu_count())
        else:
            num_workers = 0

    # If using multiple workers, configure each worker to prefetch as many samples as it can,
    # up to the aggregate device batch size.
    # If not using workers, the torch DataLoader expects the default value for prefetch_factor,
    # which, counterintuitively, must be 2.
    # On macOS, PyTorch requires prefetch_factor to be None, since num_workers is always zero.
    if 'macos' in platform.platform().lower() and num_workers == 0:
        prefetch_factor = None
    else:
        prefetch_factor = max(1, 2 * batch_size //
                              num_workers) if num_workers > 0 else 2

    return DataLoader(
        dataset=dataset,
        sampler=None,
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
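
# For example, with batch_size=512 and num_workers=8, each worker prefetches up
# to 2 * 512 // 8 = 128 samples; with num_workers=0, prefetch_factor falls back
# to the torch default of 2 (or None on macOS).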


def generate_samples(
        loader: DataLoader,
        truncate_num_samples: Optional[int] = None
) -> Iterable[Dict[str, Any]]:
    """Generator over samples of a dataloader.

    Args:
        loader (DataLoader): A dataloader emitting batches like {key: [sample0, sample1, ...]}.
        truncate_num_samples (Optional[int]): An optional number of samples to stop at.

    Yields:
        Sample dicts.
    """
    n_samples = 0
    for batch in loader:
        keys = list(batch.keys())
        current_bs = len(batch[keys[0]])
        for idx in range(current_bs):
            if truncate_num_samples is not None and n_samples == truncate_num_samples:
                return
            n_samples += 1
            yield {k: v[idx] for k, v in batch.items()}
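
# Sketch of the reshaping generate_samples performs: a batch like
#   {'prompt': [p0, p1], 'response': [r0, r1]}
# is flattened into per-sample dicts
#   {'prompt': p0, 'response': r0}, {'prompt': p1, 'response': r1}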


def main(args: Namespace) -> None:
    """Main: create a streaming dataset.

    Args:
        args (Namespace): Command-line arguments.
    """
    if args.skip_preprocessing:
        preprocessing_fn = lambda x: x  # Just an identity function
    else:
        preprocessor_str = args.preprocessor
        preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(
            preprocessor=preprocessor_str, dataset_name=args.dataset)
        if preprocessing_fn is None:
            raise ValueError(
                '`args.preprocessor` was not set and no preprocessing function ' +\
                'has been registered for `args.dataset`. If this was intentional ' +\
                '(e.g., because your dataset is already correctly formatted), ' +\
                'include the "--skip-preprocessing" flag to avoid this error.'
            )
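
    # A preprocessing function maps a raw dataset sample to a dict with
    # 'prompt' and 'response' keys. A minimal (hypothetical) example:
    #
    #   def my_preprocessor(sample: Dict) -> Dict[str, str]:
    #       return {'prompt': sample['question'], 'response': sample['answer']}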

    columns = ['prompt', 'response']

    for split_name in args.splits:
        dataset = hf_datasets.load_dataset(path=args.dataset,
                                           name=args.data_subset,
                                           split=split_name,
                                           streaming=True)
        loader = build_dataloader(dataset=dataset,
                                  batch_size=512,
                                  num_workers=args.num_workers)
        samples = generate_samples(loader)

        # Write samples
        print(f'Converting {split_name} to MDS format...')
        out = os.path.join(args.out_root, split_name)
        if args.local is not None:
            out = (os.path.join(args.local, split_name), out)
            keep_local = True
        else:
            keep_local = False
        with MDSWriter(columns={key: 'str' for key in columns},
                       out=out,
                       compression=args.compression,
                       keep_local=keep_local) as writer:
            for sample in tqdm(samples, desc=split_name):
                formatted_sample = preprocessing_fn(sample)
                if ('prompt'
                        not in formatted_sample) or ('response'
                                                     not in formatted_sample):
                    raise KeyError(
                        'Unable to tokenize example because it has not been properly formatted. ' +\
                        '"prompt" and "response" are required keys but at least one was missing ' +\
                        f'from {formatted_sample=}.'
                    )
                # convert to bytes to store in MDS binary format
                encoded_sample = {
                    key: formatted_sample[key].encode('utf-8')
                    for key in columns
                }
                writer.write(encoded_sample)


if __name__ == '__main__':
    """Example for converting Muennighoff/P3:

    >>> python convert_finetuning_dataset.py \
    >>>    --dataset "Muennighoff/P3" \
    >>>    --splits train validation \
    >>>    --preprocessor llmfoundry.data.finetuning.tasks:p3_preprocessing_function \
    >>>    --out_root s3://<bucket>/muennighoff-p3
    """
    main(parse_args())