import torch

from . import BaseWrapperDataset, data_utils


class AddTargetDataset(BaseWrapperDataset):
    """Wrap a dataset and attach a label sequence to each example.

    The wrapped dataset's collater must return a dict containing an "id"
    tensor, so that labels can be matched back to whichever samples survive
    collation. If ``batch_targets`` is True, labels are padded into a single
    tensor with ``pad``; if ``add_to_input`` is True, ``eos`` is used to
    build teacher-forced decoder inputs.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        add_to_input=False,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.add_to_input = add_to_input

    def get_label(self, index):
        # Apply the optional label post-processing hook, if one was given.
        return (
            self.labels[index]
            if self.process_label is None
            else self.process_label(self.labels[index])
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index)
        return item

    def size(self, index):
        # Report both the source size and the label length so that
        # size-based filtering can constrain either dimension.
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated

        # The wrapped collater may have dropped samples (e.g. invalid ones),
        # so keep only the labels whose ids survived.
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target

        if self.add_to_input:
            # Note: this branch assumes batch_targets=True, since `target`
            # must already be a padded tensor here for `new_full` to work.
            eos = target.new_full((target.size(0), 1), self.eos)
            # Append eos to the target and prepend it to the decoder input
            # for teacher forcing.
            collated["target"] = torch.cat([target, eos], dim=-1).long()
            collated["net_input"]["prev_output_tokens"] = torch.cat(
                [eos, target], dim=-1
            ).long()
            collated["ntokens"] += target.size(0)
        return collated
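

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It shows
# what the collater produces when batching targets. `_DummyBase` is a
# hypothetical stand-in for a real base dataset (e.g. a raw-audio dataset) and
# only implements the interface AddTargetDataset actually touches; the label
# values and pad/eos indices below are likewise assumptions for demonstration.
# Because of the relative import above, run it as a module, e.g.
# `python -m fairseq.data.add_target_dataset` (assuming that module path).
# -----------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyBase:
        """Hypothetical minimal dataset exposing __getitem__/size/collater."""

        def __init__(self, items):
            self.items = items

        def __getitem__(self, index):
            return {"id": index, "source": self.items[index]}

        def __len__(self):
            return len(self.items)

        def size(self, index):
            return self.items[index].numel()

        def collater(self, samples):
            # A real collater would also batch the sources; the "id" tensor
            # is the only field AddTargetDataset.collater requires.
            return {"id": torch.LongTensor([s["id"] for s in samples])}

    base = _DummyBase([torch.randn(4), torch.randn(6)])
    labels = [torch.LongTensor([5, 6]), torch.LongTensor([7])]
    ds = AddTargetDataset(base, labels, pad=1, eos=2, batch_targets=True)
    batch = ds.collater([ds[0], ds[1]])
    print(batch["target"])          # tensor([[5, 6], [7, 1]]) -- padded with pad=1
    print(batch["target_lengths"])  # tensor([2, 1])
    print(batch["ntokens"])         # 3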