|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Sequence, Tuple |
|
|
|
import torch |
|
from transformers import DataCollatorForSeq2Seq |
|
|
|
|
|
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Collator that turns pairwise (chosen / rejected) examples into one padded batch.

    Every input feature is expected to provide ``prompt_ids``, ``chosen_ids`` and
    ``rejected_ids``. The prompt is concatenated with each answer, the resulting
    sequences are padded by the tokenizer, and a ``labels`` tensor is attached in
    which everything except the answer tokens is set to ``label_pad_token_id``.
    """

    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
        r"""
        Builds labels that keep only the answer tokens of each padded sequence.

        Args:
            batch: padded token ids, iterated row by row.
            positions: one ``(prompt_len, answer_len)`` pair per row of ``batch``.

        Returns:
            A contiguous tensor stacked from the per-row labels, where every
            position outside the answer span holds ``label_pad_token_id``.
        """
        label_rows = []
        for sequence, (prompt_len, answer_len) in zip(batch, positions):
            # With left padding the answer sits at the very end of the row;
            # otherwise it starts right after the prompt.
            answer_start = (
                sequence.size(0) - answer_len if self.tokenizer.padding_side == "left" else prompt_len
            )
            answer_span = slice(answer_start, answer_start + answer_len)

            # Start from a row filled with the ignore value, then copy the answer back in.
            row_labels = self.label_pad_token_id * torch.ones_like(sequence)
            row_labels[answer_span] = sequence[answer_span]
            label_rows.append(row_labels)

        stacked = torch.stack(label_rows, dim=0)
        return stacked.contiguous()

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads the pairwise examples according to the collator's padding settings.

        For n input features the batch holds 2 * n sequences: the first n are the
        prompt + chosen answers, the last n are the prompt + rejected answers.

        Args:
            features: examples providing ``prompt_ids``, ``chosen_ids`` and ``rejected_ids``.

        Returns:
            The output of ``tokenizer.pad`` with an extra ``labels`` entry.
        """
        model_inputs = []
        answer_spans = []
        # Outer loop over the answer kind keeps all chosen rows ahead of all rejected rows.
        for answer_key in ("chosen_ids", "rejected_ids"):
            for example in features:
                prompt_ids = example["prompt_ids"]
                answer_ids = example[answer_key]
                prompt_len, answer_len = len(prompt_ids), len(answer_ids)

                model_inputs.append(
                    {
                        "input_ids": prompt_ids + answer_ids,
                        "attention_mask": [1] * (prompt_len + answer_len),
                    }
                )
                answer_spans.append((prompt_len, answer_len))

        batch = self.tokenizer.pad(
            model_inputs,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch["labels"] = self._pad_labels(batch["input_ids"], answer_spans)
        return batch
|
|