# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import warnings
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

log = logging.getLogger(__name__)

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100


class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format) or of each of the
            context and target sequences (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if True) or an encoder-decoder model (if False).
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in smaller but inconsistent batch
            sizes. Default: ``False``, which ensures that all sequences are
            padded to ``max_seq_len``.
        separator_text (str | bool, optional): If a string is provided, it
            will be used to separate the context and target sequences (it is
            appended to the end of the context). If ``True``, the tokenizer's
            sep_token is used, which must be defined. Only applicable for
            decoder-only formatting.
        format_for_generation (bool, optional): Whether to format the batch such
            that context and target sequences remain separated, which is useful
            when using the context to generate text which should be compared to the
            target (e.g., during evaluation). Default: ``False``.
        batch_metadata (dict, optional): A dictionary of metadata which will be added
            to the batch.
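
    Example:
        A minimal illustrative sketch (not a doctest). It assumes a GPT-2
        tokenizer, which ships without a pad token, so eos is reused for
        padding; the prompt and target strings are placeholders::

            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained('gpt2')
            tokenizer.pad_token = tokenizer.eos_token  # no pad token in GPT-2

            collator = Seq2SeqFinetuningCollator(
                tokenizer=tokenizer,
                max_seq_len=16,
                decoder_only_format=True,
            )

            # Examples arrive pre-tokenized: input_ids holds the context and
            # labels holds the target continuation (toy inputs only)
            prompt = tokenizer('Translate to French: hello',
                               add_special_tokens=False)
            target = tokenizer(' bonjour', add_special_tokens=False)
            batch = collator([{
                'input_ids': prompt.input_ids,
                'attention_mask': prompt.attention_mask,
                'labels': target.input_ids,
            }])
            # batch holds padded input_ids, attention_mask, labels, and
            # bidirectional_mask tensors, each of shape (1, 16)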
    """

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_seq_len: int,
        decoder_only_format: bool,
        allow_pad_trimming: bool = False,
        separator_text: Optional[Union[str, bool]] = None,
        format_for_generation: bool = False,
        batch_metadata: Optional[Dict[str, Any]] = None,
    ):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.format_for_generation = format_for_generation
        self.batch_metadata = batch_metadata or {}

        # Trimming will always be skipped on at least the first __call__
        self._allow_pad_trimming = allow_pad_trimming
        self._seen_first_batch = False

        illegal_keys = [
            'input_ids', 'labels', 'attention_mask', 'decoder_input_ids',
            'decoder_attention_mask', 'generate_output'
        ]
        found_keys = []
        for illegal_key in illegal_keys:
            if illegal_key in self.batch_metadata:
                found_keys.append(illegal_key)
        if found_keys:
            raise ValueError(
                f'The following keys are in batch_metadata but are not allowed: {", ".join(found_keys)}.\n' +\
                f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' +\
                f'{", ".join(illegal_keys)}'
            )
        if self.format_for_generation:
            self.batch_metadata['generate_output'] = True

        if (max_seq_len % 8) != 0:
            log.warning(
                'For performance, a multiple of 8 is recommended for max_seq_len.'
            )

        if self.tokenizer.pad_token_id is None:
            raise ValueError(
                f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None'
            )

        self.separator_tokens = []
        if separator_text and decoder_only_format:
            if separator_text is True:
                # Use the tokenizer's sep token or raise an error if undefined
                if self.tokenizer.sep_token_id is None:
                    raise ValueError(
                        'Setting separator_text=True requires that the tokenizer has sep_token_id but it has not been set. ' +\
                        'Please pass a string argument for separator_text or set sep_token_id in the tokenizer.'
                    )
                self.separator_tokens = [self.tokenizer.sep_token_id]
            else:
                # Convert the string separator_text into token(s)
                self.separator_tokens = tokenizer(
                    separator_text, add_special_tokens=False).input_ids

        self._warned_context = False
        self._warned_target = False

    def __call__(self, examples: List[Dict[str,
                                           Any]]) -> Dict[str, torch.Tensor]:
        for check_key in ['input_ids', 'labels', 'attention_mask']:
            if check_key not in examples[0]:
                raise KeyError(
                    f'Examples returned by dataset do not include required key: {check_key}'
                )

        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)

        # Add any batch_metadata
        batch_size = batch['input_ids'].shape[0]
        batch.update({
            k: torch.tensor([v] * batch_size)
            for k, v in self.batch_metadata.items()
        })

        return batch

    def _process_and_batch_decoder_only(
            self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Steps explained in comments
        processed_examples = []
        for example in examples:
            context = ensure_list(example['input_ids'])
            target = ensure_list(example['labels'])
            # First, get rid of any padding tokens
            context = [t for t in context if t != self.tokenizer.pad_token_id]
            target = [t for t in target if t != self.tokenizer.pad_token_id]
            # Second, append any separator tokens to the context tokens
            if self.separator_tokens:
                context = context + self.separator_tokens
            # Third, ensure that the target text ends with an eos token
            if target[-1] != self.tokenizer.eos_token_id:
                target = target + [self.tokenizer.eos_token_id]

            n_context = len(context)
            n_target = len(target)

            if n_context >= self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Skipping example because CONTEXT length={n_context} leaves no room ' +\
                        f'for TARGET tokens with max_seq_len={self.max_seq_len}. ' +\
                        f'If this causes downstream issues because of inconsistent batch sizes, ' +\
                        f'consider increasing max_seq_len or using example packing.'
                    )
                    self._warned_context = True
                continue

            if self.format_for_generation:
                # When formatting for generation, we need to keep input_ids and
                # labels separate. The input_ids (context) will be fed into the
                # generator and the labels will be used by the eval metric.
                input_ids = context[-self.max_seq_len:]
                n_context = len(input_ids)
                attention_mask = [1] * n_context
                bidirectional_mask = [1] * n_context
                # Annoyingly, we need to pad everything except input_ids
                # and attention_mask ourselves
                i_pad = [self.tokenizer.pad_token_id
                        ] * (self.max_seq_len - n_target)
                z_pad = [0] * (self.max_seq_len - n_context)
                if self.tokenizer.padding_side == 'left':
                    labels = i_pad + target
                    bidirectional_mask = z_pad + bidirectional_mask
                else:
                    labels = target + i_pad
                    bidirectional_mask = bidirectional_mask + z_pad

            else:
                # We need to concatenate the context and target to get the
                # full input sequence, truncating the target if the combined
                # length would exceed max_seq_len
                if n_context + n_target > self.max_seq_len:
                    old_n_target = int(n_target)
                    n_target = self.max_seq_len - n_context
                    if not self._warned_target:
                        warnings.warn(
                            f'Truncating TARGET sequence of length={old_n_target} to length={n_target}, ' +\
                            f'so context+target fit max_seq_len={self.max_seq_len}. If truncation is ' +\
                            f'a problem, consider increasing max_seq_len.')
                        self._warned_target = True
                    target = target[-n_target:]
                    target[-1] = self.tokenizer.eos_token_id
                n_total = n_context + n_target

                input_ids = context + target
                labels = ([_HF_IGNORE_INDEX] * n_context) + target
                attention_mask = [1] * n_total
                # bidirectional_mask is used by our prefix lm model variants
                bidirectional_mask = ([1] * n_context) + ([0] * n_target)

                # Annoyingly, we need to pad everything except input_ids
                # and attention_mask ourselves
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
                z_pad = [0] * (self.max_seq_len - n_total)
                if self.tokenizer.padding_side == 'left':
                    labels = i_pad + labels
                    bidirectional_mask = z_pad + bidirectional_mask
                else:
                    labels = labels + i_pad
                    bidirectional_mask = bidirectional_mask + z_pad

            # Update the example
            example['input_ids'] = input_ids
            example['labels'] = labels
            example['attention_mask'] = attention_mask
            example['bidirectional_mask'] = bidirectional_mask

            processed_examples.append(example)

        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # This logic prevents trimming on at least the first batch
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True

        # The batch is ready, but we can trim padding for efficiency
        multiple_of = 8

        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k, v in batch.items():
            if len(v.shape) < 2:
                continue
            if k == 'labels' and self.format_for_generation:
                continue
            if self.tokenizer.padding_side == 'left':
                batch[k] = v[:, -keep_tokens:].contiguous()
            else:
                batch[k] = v[:, :keep_tokens].contiguous()

        return batch

    def _process_and_batch_encoder_decoder(
            self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # The encoder-decoder case has some gotchas.
        # Steps are explained in comments.
        processed_examples = []
        for example in examples:
            context = ensure_list(example['input_ids'])
            target = ensure_list(example['labels'])
            # ... first, get rid of any padding that was already applied
            context = [t for t in context if t != self.tokenizer.pad_token_id]
            target = [t for t in target if t != self.tokenizer.pad_token_id]
            # ... second, ensure that the target text ends with an eos token
            if target[-1] != self.tokenizer.eos_token_id:
                target = target + [self.tokenizer.eos_token_id]
            # ... third, we need to pad the labels ourselves, since the
            # tokenizer's pad() does not pad the labels key
            if len(target) < self.max_seq_len:
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
                target = target + i_pad
            else:
                if not self._warned_target:
                    warnings.warn(
                        f'Truncating TARGET sequence of length={len(target)} ' +\
                        f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
                        f'a problem, consider increasing max_seq_len.')
                    self._warned_target = True
                target = target[:self.max_seq_len -
                                1] + [self.tokenizer.eos_token_id]

            # We might need to truncate the context. Preserve the beginning.
            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Truncating CONTEXT sequence of length={len(context)} ' +\
                        f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
                        f'a problem, consider increasing max_seq_len.')
                    self._warned_context = True
                context = context[:self.max_seq_len -
                                  1] + [self.tokenizer.eos_token_id]

            # Back into the example
            example['input_ids'] = context
            example['attention_mask'] = [1] * len(context)
            example['labels'] = target

            processed_examples.append(example)

        # Batch examples into a single dict (this also pads)
        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )
        # We're still missing decoder_input_ids and decoder_attention_mask
        batch['decoder_input_ids'] = torch.cat(
            [
                torch.full((len(processed_examples), 1),
                           self.tokenizer.pad_token_id),
                batch['labels'][:, :-1],
            ],
            dim=1,
        )
        batch['decoder_input_ids'].masked_fill_(
            batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
            self.tokenizer.pad_token_id)
        batch['decoder_attention_mask'] = torch.not_equal(
            batch['labels'], _HF_IGNORE_INDEX)

        # This logic prevents trimming on at least the first batch
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True

        # The batch is now valid, but we can trim padding for efficiency
        multiple_of = 8
        # (first for the encoder)
        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['input_ids', 'attention_mask']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()
        # (then for the decoder)
        n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()

        return batch


def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Return ``x`` as a flat list, converting from a tensor if necessary."""
    if isinstance(x, torch.Tensor):
        x = list(x.flatten())
    assert isinstance(x, list)
    return x