import logging
import warnings
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

log = logging.getLogger(__name__)

_HF_IGNORE_INDEX = -100


class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format) or of each of the
            context and target sequences (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if ``True``) or an encoder-decoder model
            (if ``False``).
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in smaller but inconsistent
            batch sizes. Default: ``False`` ensures that all sequences are
            max_seq_len.
        separator_text (str | bool, optional): If a string is provided, it
            will be used to separate the context and target sequences
            (appended to the end of the context). If ``True``, the tokenizer's
            sep_token will be used, which must be defined. Only applicable
            for decoder-only formatting.
        format_for_generation (bool, optional): Whether to format the batch
            such that the context and target sequences remain separated,
            which is useful when using the context to generate text that
            should be compared to the target (e.g., during evaluation).
            Default: ``False``.
        batch_metadata (dict, optional): A dictionary of metadata which will
            be added to the batch.
    """

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_seq_len: int,
        decoder_only_format: bool,
        allow_pad_trimming: bool = False,
        separator_text: Optional[Union[str, bool]] = None,
        format_for_generation: bool = False,
        batch_metadata: Optional[Dict[str, Any]] = None,
    ):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.format_for_generation = format_for_generation
        self.batch_metadata = batch_metadata or {}

        # Padding trimming is always skipped on (at least) the first batch
        self._allow_pad_trimming = allow_pad_trimming
        self._seen_first_batch = False

        # batch_metadata must not shadow any keys the models consume directly
        illegal_keys = [
            'input_ids', 'labels', 'attention_mask', 'decoder_input_ids',
            'decoder_attention_mask', 'generate_output'
        ]
        found_keys = []
        for illegal_key in illegal_keys:
            if illegal_key in self.batch_metadata:
                found_keys.append(illegal_key)
        if found_keys:
            raise ValueError(
                f'The following keys are in batch_metadata but are not allowed: {", ".join(found_keys)}.\n' +\
                f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' +\
                f'{", ".join(illegal_keys)}'
            )
        if self.format_for_generation:
            self.batch_metadata['generate_output'] = True

        if (max_seq_len % 8) != 0:
            log.warning(
                'For performance, a max_seq_len that is a multiple of 8 is recommended.'
            )

        if self.tokenizer.pad_token_id is None:
            raise ValueError(
                f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None'
            )

        self.separator_tokens = []
        if separator_text and decoder_only_format:
            if separator_text is True:
                # Use the tokenizer's sep token to separate context and target
                if self.tokenizer.sep_token_id is None:
                    raise ValueError(
                        'Setting separator_text=True requires that the tokenizer has sep_token_id but it has not been set. ' +\
                        'Please pass a string argument for separator_text or set sep_token_id in the tokenizer.'
                    )
                self.separator_tokens = [self.tokenizer.sep_token_id]
            else:
                # Tokenize the separator text without adding special tokens
                self.separator_tokens = tokenizer(
                    separator_text, add_special_tokens=False).input_ids

        self._warned_context = False
        self._warned_target = False

    def __call__(self, examples: List[Dict[str,
                                           Any]]) -> Dict[str, torch.Tensor]:
        for check_key in ['input_ids', 'labels', 'attention_mask']:
            if check_key not in examples[0]:
                raise KeyError(
                    f'Examples returned by dataset do not include required key: {check_key}'
                )

        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)

        # Add any batch_metadata, replicated across the batch dimension
        batch_size = batch['input_ids'].shape[0]
        batch.update({
            k: torch.tensor([v] * batch_size)
            for k, v in self.batch_metadata.items()
        })

        return batch

    def _process_and_batch_decoder_only(
            self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        processed_examples = []
        for example in examples:
            context = ensure_list(example['input_ids'])
            target = ensure_list(example['labels'])
            # Strip any padding tokens already present in the raw example
            context = [t for t in context if t != self.tokenizer.pad_token_id]
            target = [t for t in target if t != self.tokenizer.pad_token_id]
            # Append the separator tokens (if any) to the end of the context
            if self.separator_tokens:
                context = context + self.separator_tokens
            # Make sure the target ends with an EOS token
            if target[-1] != self.tokenizer.eos_token_id:
                target = target + [self.tokenizer.eos_token_id]

            n_context = len(context)
            n_target = len(target)

            if n_context >= self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Skipping example because CONTEXT length={n_context} leaves no room ' +\
                        f'for TARGET tokens because max_seq_len={self.max_seq_len}. ' +\
                        f'If this causes downstream issues because of inconsistent batch sizes, ' +\
                        f'consider increasing max_seq_len or using example packing.'
                    )
                    self._warned_context = True
                continue

            if self.format_for_generation:
                # Keep the context and target separated: input_ids holds only
                # the context and labels holds only the target, each padded
                # out to max_seq_len below.
                input_ids = context[-self.max_seq_len:]
                n_context = len(input_ids)
                attention_mask = [1] * n_context
                bidirectional_mask = [1] * n_context

                # Pad the labels and the bidirectional mask to max_seq_len
                i_pad = [self.tokenizer.pad_token_id
                        ] * (self.max_seq_len - n_target)
                z_pad = [0] * (self.max_seq_len - n_context)
                if self.tokenizer.padding_side == 'left':
                    labels = i_pad + target
                    bidirectional_mask = z_pad + bidirectional_mask
                else:
                    labels = target + i_pad
                    bidirectional_mask = bidirectional_mask + z_pad

            else:
                # Standard training format: concatenate context and target
                # into input_ids, and mask the context positions (and any
                # padding) in the labels with the ignore index.
                if n_context + n_target > self.max_seq_len:
                    old_n_target = int(n_target)
                    n_target = self.max_seq_len - n_context
                    if not self._warned_target:
                        warnings.warn(
                            f'Truncating TARGET sequence of length={old_n_target} to length={n_target}, ' +\
                            f'so context+target fit max_seq_len={self.max_seq_len}. If truncation is ' +\
                            f'a problem, consider increasing max_seq_len.')
                        self._warned_target = True
                    target = target[-n_target:]
                    target[-1] = self.tokenizer.eos_token_id
                n_total = n_context + n_target

                input_ids = context + target
                labels = ([_HF_IGNORE_INDEX] * n_context) + target
                attention_mask = [1] * n_total
                # The bidirectional mask is 1 over context tokens, 0 elsewhere
                bidirectional_mask = ([1] * n_context) + ([0] * n_target)

                # Pad the labels and the bidirectional mask to max_seq_len
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
                z_pad = [0] * (self.max_seq_len - n_total)
                if self.tokenizer.padding_side == 'left':
                    labels = i_pad + labels
                    bidirectional_mask = z_pad + bidirectional_mask
                else:
                    labels = labels + i_pad
                    bidirectional_mask = bidirectional_mask + z_pad

            # The tokenizer will handle padding for input_ids/attention_mask
            example['input_ids'] = input_ids
            example['labels'] = labels
            example['attention_mask'] = attention_mask
            example['bidirectional_mask'] = bidirectional_mask

            processed_examples.append(example)

        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # Always return the full-length batch at least once before trimming
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True

        # Trim padding so the sequence length is the smallest multiple of 8
        # that covers the longest non-padded sequence in the batch
        multiple_of = 8

        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k, v in batch.items():
            if len(v.shape) < 2:
                continue
            if k == 'labels' and self.format_for_generation:
                continue
            if self.tokenizer.padding_side == 'left':
                batch[k] = v[:, -keep_tokens:].contiguous()
            else:
                batch[k] = v[:, :keep_tokens].contiguous()

        return batch
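
    # Illustrative decoder-only layout (hypothetical token ids), assuming
    # max_seq_len=8, padding_side='right', format_for_generation=False,
    # PAD=0, EOS=2, and _HF_IGNORE_INDEX=-100. For context=[11, 22] and
    # target=[33, 44], the method above produces (after tokenizer.pad):
    #   input_ids          = [  11,   22, 33, 44,  2,    0,    0,    0]
    #   labels             = [-100, -100, 33, 44,  2, -100, -100, -100]
    #   attention_mask     = [   1,    1,  1,  1,  1,    0,    0,    0]
    #   bidirectional_mask = [   1,    1,  0,  0,  0,    0,    0,    0]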

    def _process_and_batch_encoder_decoder(
            self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        processed_examples = []
        for example in examples:
            context = ensure_list(example['input_ids'])
            target = ensure_list(example['labels'])
            # Strip any padding tokens already present in the raw example
            context = [t for t in context if t != self.tokenizer.pad_token_id]
            target = [t for t in target if t != self.tokenizer.pad_token_id]
            # Make sure the target ends with an EOS token
            if target[-1] != self.tokenizer.eos_token_id:
                target = target + [self.tokenizer.eos_token_id]

            # Pad the target with the ignore index up to max_seq_len, or
            # truncate it (ending with EOS) if it is too long
            if len(target) < self.max_seq_len:
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
                target = target + i_pad
            else:
                if not self._warned_target:
                    warnings.warn(
                        f'Truncating TARGET sequence of length={len(target)} ' +\
                        f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
                        f'a problem, consider increasing max_seq_len.')
                    self._warned_target = True
                target = target[:self.max_seq_len -
                                1] + [self.tokenizer.eos_token_id]

            # Truncate the context (ending with EOS) if it is too long
            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Truncating CONTEXT sequence of length={len(context)} ' +\
                        f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
                        f'a problem, consider increasing max_seq_len.')
                    self._warned_context = True
                context = context[:self.max_seq_len -
                                  1] + [self.tokenizer.eos_token_id]

            # The tokenizer will handle padding for input_ids/attention_mask
            example['input_ids'] = context
            example['attention_mask'] = [1] * len(context)
            example['labels'] = target

            processed_examples.append(example)

        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # Build decoder_input_ids by prepending the pad token and shifting the
        # labels right by one, then replacing any ignore-index entries with
        # the pad token id
        batch['decoder_input_ids'] = torch.cat([
            torch.full((len(processed_examples), 1),
                       self.tokenizer.pad_token_id),
            batch['labels'][:, :-1],
        ], dim=1)
        batch['decoder_input_ids'].masked_fill_(
            batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
            self.tokenizer.pad_token_id)
        batch['decoder_attention_mask'] = torch.not_equal(
            batch['labels'], _HF_IGNORE_INDEX)
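        # For example (hypothetical ids, PAD=0, EOS=2, ignore index=-100), a
        # labels row of [33, 44, 2, -100, -100] yields
        # decoder_input_ids = [0, 33, 44, 2, 0] and
        # decoder_attention_mask = [1, 1, 1, 0, 0].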

        # Always return the full-length batch at least once before trimming
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True

        # Trim padding so each side's length is the smallest multiple of 8
        # that covers its longest non-padded sequence
        multiple_of = 8

        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['input_ids', 'attention_mask']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()

        n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()

        return batch


def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Return x as a list, flattening it first if it is a tensor."""
    if isinstance(x, torch.Tensor):
        x = list(x.flatten())
    assert isinstance(x, list)
    return x
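

# Usage sketch (illustrative only, not part of the collator). It assumes the
# Hugging Face 'gpt2' tokenizer can be loaded and that examples arrive
# pre-tokenized with 'input_ids' (context), 'labels' (target), and
# 'attention_mask' keys, as this collator expects. The tokenizer name and the
# toy texts below are arbitrary choices for demonstration.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    # GPT-2 has no pad token by default, but the collator requires one
    tokenizer.pad_token = tokenizer.eos_token

    collator = Seq2SeqFinetuningCollator(
        tokenizer=tokenizer,
        max_seq_len=32,
        decoder_only_format=True,
    )

    # Two toy pre-tokenized examples (actual ids depend on the tokenizer)
    examples = []
    for context_text, target_text in [('Translate to French: cat', 'chat'),
                                      ('Translate to French: dog', 'chien')]:
        context = tokenizer(context_text)
        target = tokenizer(target_text)
        examples.append({
            'input_ids': context['input_ids'],
            'attention_mask': context['attention_mask'],
            'labels': target['input_ids'],
        })

    batch = collator(examples)
    # Each value is a (2, 32) tensor: input_ids, attention_mask, labels,
    # and bidirectional_mask
    print({k: tuple(v.shape) for k, v in batch.items()})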