# crystal-technologies's picture
# Upload 303 files
# de4ade4
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import logging
import warnings
from typing import Any, Dict, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
# Module-level logger named after this module; used for non-fatal
# configuration warnings emitted by the collator.
log = logging.getLogger(__name__)

# HuggingFace hardcodes the ignore index to -100. The collator in this file
# writes this value into ``labels`` at context positions and at label-padding
# positions, and the encoder-decoder path later swaps it back to the pad token
# when deriving ``decoder_input_ids``.
_HF_IGNORE_INDEX = -100
class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Each example passed to the collator must be a mapping with at least the
    keys ``input_ids`` (the context tokens), ``labels`` (the target tokens)
    and ``attention_mask``. Input examples are never modified; the collator
    works on shallow copies.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format) or of each of the
            context sequence and target sequence (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if True) or an encoder-decoder model (if False).
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in shorter sequence dimensions
            that vary from batch to batch. Trimming is always skipped on the
            first batch. Default: ``False`` ensures that all sequences are
            max_seq_len.
        separator_text (str | bool, optional): If a string is provided, it will
            be used to separate the context and target sequences (appended to end
            of context). If ``True``, will use the tokenizer's sep_token, which must
            be defined. Only applicable for decoder-only formatting.
        format_for_generation (bool, optional): Whether to format the batch such
            that context and target sequences remain separated, which is useful
            when using the context to generate text which should be compared to the
            target (e.g., during evaluation). Default: ``False``.
        batch_metadata (dict, optional): A dictionary of metadata which will be added
            to the batch. The dictionary is copied, so the caller's object is
            never modified.

    Raises:
        ValueError: If ``batch_metadata`` contains a key reserved for model
            inputs, if the tokenizer has no pad token, or if
            ``separator_text=True`` but the tokenizer has no sep token.
    """

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_seq_len: int,
        decoder_only_format: bool,
        allow_pad_trimming: bool = False,
        separator_text: Optional[Union[str, bool]] = None,
        format_for_generation: bool = False,
        batch_metadata: Optional[Dict[str, Any]] = None,
    ):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.format_for_generation = format_for_generation
        # Copy the metadata: 'generate_output' may be added below and that
        # must not leak into the dictionary owned by the caller.
        self.batch_metadata = dict(batch_metadata) if batch_metadata else {}

        # Trimming will always be skipped on at least the first __call__
        self._allow_pad_trimming = allow_pad_trimming
        self._seen_first_batch = False

        illegal_keys = [
            'input_ids', 'labels', 'attention_mask', 'decoder_input_ids',
            'decoder_attention_mask', 'generate_output'
        ]
        found_keys = [k for k in illegal_keys if k in self.batch_metadata]
        if found_keys:
            raise ValueError(
                f'The following keys are in batch_metadata but are not allowed: {", ".join(found_keys)}.\n'
                f'You cannot use keys that are used directly by the models. The prohibited keys are:\n'
                f'{", ".join(illegal_keys)}')
        if self.format_for_generation:
            self.batch_metadata['generate_output'] = True

        if (max_seq_len % 8) != 0:
            log.warning(
                'For performance, a max_seq_len as a multiple of 8 is recommended.'
            )

        if self.tokenizer.pad_token_id is None:
            raise ValueError(
                f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None'
            )

        self.separator_tokens = []
        if separator_text and decoder_only_format:
            if separator_text is True:
                # Use the tokenizer's sep token or throw an error if undefined
                if self.tokenizer.sep_token_id is None:
                    raise ValueError(
                        'Setting separator_text=True requires that the tokenizer has sep_token_id but it has not been set. '
                        'Please pass a string argument for separator_text or set sep_token_id in the tokenizer.'
                    )
                self.separator_tokens = [self.tokenizer.sep_token_id]
            else:
                # Convert the string separator_text into token(s)
                self.separator_tokens = tokenizer(
                    separator_text, add_special_tokens=False).input_ids

        # Each truncation/skip warning is emitted at most once per collator.
        self._warned_context = False
        self._warned_target = False

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` into a single padded batch.

        Args:
            examples: The examples to batch. The first example is checked for
                the required keys ``input_ids``, ``labels`` and
                ``attention_mask``.

        Returns:
            The batch produced by the decoder-only or encoder-decoder path,
            extended with one tensor per ``batch_metadata`` entry in which the
            metadata value is repeated once per row of the batch.

        Raises:
            KeyError: If the first example lacks one of the required keys.
        """
        for check_key in ['input_ids', 'labels', 'attention_mask']:
            if check_key not in examples[0]:
                raise KeyError(
                    f'Examples returned by dataset do not include required key: {check_key}'
                )

        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)

        # Add any batch_metadata
        batch_size = batch['input_ids'].shape[0]
        batch.update({
            k: torch.tensor([v] * batch_size)
            for k, v in self.batch_metadata.items()
        })

        return batch

    @staticmethod
    def _n_tokens_to_keep(attention_mask: torch.Tensor,
                          multiple_of: int = 8) -> int:
        """Return the longest unpadded row length, rounded up to ``multiple_of``."""
        n_non_padding = attention_mask.sum(dim=1).max()
        return int(multiple_of * torch.ceil(n_non_padding / multiple_of))

    def _process_and_batch_decoder_only(
            self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Build a batch for a decoder-only model.

        For every example the context and target are stripped of pad tokens,
        the separator tokens (if any) are appended to the context and the
        target is made to end with the EOS token. Examples whose context alone
        reaches ``max_seq_len`` are skipped (with a one-time warning), so the
        returned batch may contain fewer rows than ``examples``.

        Args:
            examples: Examples with ``input_ids`` (context) and ``labels``
                (target). They are not modified.

        Returns:
            The output of ``tokenizer.pad`` over the processed examples,
            containing ``input_ids``, ``labels``, ``attention_mask`` and
            ``bidirectional_mask``, optionally trimmed along the sequence
            dimension when pad trimming is enabled.
        """
        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id

        # Steps explained in comments
        processed_examples = []
        for example in examples:
            context = ensure_list(example['input_ids'])
            target = ensure_list(example['labels'])
            # First, get rid of any padding tokens
            context = [t for t in context if t != pad_id]
            target = [t for t in target if t != pad_id]
            # Second, append any separator tokens to the context tokens
            if self.separator_tokens:
                context = context + self.separator_tokens
            # Third, ensure that the target text ends with an eos tag. The
            # emptiness check keeps an all-padding target from raising
            # IndexError; such a target becomes just [eos].
            if not target or target[-1] != eos_id:
                target = target + [eos_id]

            n_context = len(context)
            n_target = len(target)

            if n_context >= self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Skipping example because CONTEXT length={n_context} leaves no room '
                        f'for TARGET tokens because max_seq_len={self.max_seq_len}. '
                        f'If this causes downstream issues because of inconsistent batch sizes, '
                        f'consider increasing max_seq_len or using example packing.'
                    )
                    self._warned_context = True
                continue

            if self.format_for_generation:
                # When formatting for generation, we need to keep input_ids and
                # labels separate. The input_ids (context) will be fed into the
                # generator and the labels will be used by the eval metric.
                input_ids = context[-self.max_seq_len:]
                n_context = len(input_ids)
                attention_mask = [1] * n_context
                bidirectional_mask = [1] * n_context
                # Annoyingly, we need to pad everything but input_ids
                # and attention_mask ourselves
                # NOTE(review): a target longer than max_seq_len is not
                # truncated here, so its labels row exceeds max_seq_len.
                i_pad = [pad_id] * (self.max_seq_len - n_target)
                z_pad = [0] * (self.max_seq_len - n_context)
                if self.tokenizer.padding_side == 'left':
                    labels = i_pad + target
                    bidirectional_mask = z_pad + bidirectional_mask
                else:
                    labels = target + i_pad
                    bidirectional_mask = bidirectional_mask + z_pad

            else:
                # We need to concatenate the context and target to get the
                # full input sequence, cutting off any excess tokens from the
                # end of the target
                if n_context + n_target > self.max_seq_len:
                    old_n_target = int(n_target)
                    n_target = self.max_seq_len - n_context
                    if not self._warned_target:
                        warnings.warn(
                            f'Truncating TARGET sequence of length={old_n_target} to length={n_target}, '
                            f'so context+target fit max_seq_len={self.max_seq_len}. If truncation is '
                            f'a problem, consider increasing max_seq_len.')
                        self._warned_target = True
                    # Keep the beginning of the target and drop the excess
                    # from its end, then restore the EOS that was cut off.
                    target = target[:n_target]
                    target[-1] = eos_id
                n_total = n_context + n_target

                input_ids = context + target
                # Context positions are masked out of the labels.
                labels = ([_HF_IGNORE_INDEX] * n_context) + target
                attention_mask = [1] * n_total
                # bidirectional_mask is used by our prefix lm model variants
                bidirectional_mask = ([1] * n_context) + ([0] * n_target)

                # Annoyingly, we need to pad everything but input_ids
                # and attention_mask ourselves
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
                z_pad = [0] * (self.max_seq_len - n_total)
                if self.tokenizer.padding_side == 'left':
                    labels = i_pad + labels
                    bidirectional_mask = z_pad + bidirectional_mask
                else:
                    labels = labels + i_pad
                    bidirectional_mask = bidirectional_mask + z_pad

            # Write the results into a shallow copy so the caller's example
            # (which a dataset may hand out again) is left untouched.
            processed = dict(example)
            processed['input_ids'] = input_ids
            processed['labels'] = labels
            processed['attention_mask'] = attention_mask
            processed['bidirectional_mask'] = bidirectional_mask

            processed_examples.append(processed)

        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # This logic prevents trimming on at least the first batch
        should_trim = self._allow_pad_trimming and self._seen_first_batch
        self._seen_first_batch = True
        if not should_trim:
            return batch

        # The batch is ready, but we can trim padding for efficiency
        keep_tokens = self._n_tokens_to_keep(batch['attention_mask'])
        for k, v in batch.items():
            if len(v.shape) < 2:
                continue
            # In generation mode the labels are padded independently of the
            # context, so the context-derived length does not apply to them.
            if k == 'labels' and self.format_for_generation:
                continue
            if self.tokenizer.padding_side == 'left':
                batch[k] = v[:, -keep_tokens:].contiguous()
            else:
                batch[k] = v[:, :keep_tokens].contiguous()

        return batch

    def _process_and_batch_encoder_decoder(
            self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Build a batch for an encoder-decoder model.

        The context and target are handled as separate sequences, each limited
        to ``max_seq_len``. Over-long sequences are truncated at the end and
        terminated with the EOS token (with a one-time warning).

        Args:
            examples: Examples with ``input_ids`` (context) and ``labels``
                (target). They are not modified.

        Returns:
            The output of ``tokenizer.pad`` over the processed examples,
            extended with ``decoder_input_ids`` (labels shifted right by one,
            starting with the pad token) and ``decoder_attention_mask``,
            optionally trimmed along the sequence dimension when pad trimming
            is enabled.
        """
        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id

        # The encoder-decoder case has some gotchas.
        # Steps are explained in comments.
        processed_examples = []
        for example in examples:
            context = ensure_list(example['input_ids'])
            target = ensure_list(example['labels'])
            # ... first, get rid of any padding that was already applied
            context = [t for t in context if t != pad_id]
            target = [t for t in target if t != pad_id]
            # ... second, ensure that the target text ends with an eos tag.
            # The emptiness check keeps an all-padding target from raising
            # IndexError; such a target becomes just [eos].
            if not target or target[-1] != eos_id:
                target = target + [eos_id]
            # ... third, we need to pad labels ourselves. Because HF.
            if len(target) < self.max_seq_len:
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
                target = target + i_pad
            elif len(target) > self.max_seq_len:
                # A target of exactly max_seq_len already fits and already
                # ends with eos, so only strictly longer targets are cut.
                if not self._warned_target:
                    warnings.warn(
                        f'Truncating TARGET sequence of length={len(target)} '
                        f'to max_seq_len={self.max_seq_len}. If truncation is '
                        f'a problem, consider increasing max_seq_len.')
                    self._warned_target = True
                target = target[:self.max_seq_len - 1] + [eos_id]

            # We might need to truncate the context. Preserve the beginning.
            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Truncating CONTEXT sequence of length={len(context)} '
                        f'to max_seq_len={self.max_seq_len}. If truncation is '
                        f'a problem, consider increasing max_seq_len.')
                    self._warned_context = True
                context = context[:self.max_seq_len - 1] + [eos_id]

            # Write the results into a shallow copy so the caller's example
            # (which a dataset may hand out again) is left untouched.
            processed = dict(example)
            processed['input_ids'] = context
            processed['attention_mask'] = [1] * len(context)
            processed['labels'] = target

            processed_examples.append(processed)

        # Batch examples into a single dict (this also pads)
        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # We're still missing decoder_input_ids and decoder_attention_mask.
        # decoder_input_ids is the labels shifted right by one position with
        # a pad token in front; ignore-index entries are mapped back to pad.
        batch['decoder_input_ids'] = torch.cat([
            torch.full((len(processed_examples), 1), pad_id),
            batch['labels'][:, :-1]
        ],
                                               dim=1)
        batch['decoder_input_ids'].masked_fill_(
            batch['decoder_input_ids'] == _HF_IGNORE_INDEX, pad_id)
        batch['decoder_attention_mask'] = torch.not_equal(
            batch['labels'], _HF_IGNORE_INDEX)

        # This logic prevents trimming on at least the first batch
        should_trim = self._allow_pad_trimming and self._seen_first_batch
        self._seen_first_batch = True
        if not should_trim:
            return batch

        # The batch is now valid, but we can trim padding for efficiency
        # (first for the encoder)
        keep_tokens = self._n_tokens_to_keep(batch['attention_mask'])
        for k in ['input_ids', 'attention_mask']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()
        # (then for the decoder)
        keep_tokens = self._n_tokens_to_keep(batch['decoder_attention_mask'])
        for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()

        return batch
def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Return ``x`` as a flat Python list.

    Args:
        x: Either a list, which is returned as-is (same object), or a
            ``torch.Tensor``, which is flattened and converted.

    Returns:
        A list. For tensor input the elements are plain Python scalars.

    Raises:
        TypeError: If ``x`` is neither a list nor a ``torch.Tensor``.
    """
    if isinstance(x, torch.Tensor):
        # tolist() converts to native Python scalars. list(tensor) would
        # instead produce 0-dim tensors, which then leak into comparisons
        # and into the lists handed to the tokenizer for padding.
        return x.flatten().tolist()
    if not isinstance(x, list):
        # Explicit raise rather than assert so the check survives `python -O`.
        raise TypeError(
            f'Expected a list or torch.Tensor but got {type(x).__name__}')
    return x