Spaces:
Sleeping
Sleeping
from abc import abstractmethod | |
from models.tokenizer import TokenAligner | |
class PTCollator():
    """Base collator that routes a batch to a mode-specific hook.

    Callers go through ``collate`` and pick the hook with the ``type``
    argument. Subclasses are expected to override ``collate_train``,
    ``collate_test`` and ``collate_correct``.
    """

    # The annotation is quoted so it is not evaluated when the class body runs.
    def __init__(self, tokenAligner: "TokenAligner"):
        """Keep ``tokenAligner`` on the instance for the subclass hooks to use."""
        self.tokenAligner = tokenAligner

    def collate(self, dataloader_batch, type="train") -> dict:
        """Collate ``dataloader_batch`` with the hook selected by ``type``.

        Args:
            dataloader_batch: The batch, handed unchanged to the chosen hook.
            type: ``"train"``, ``"test"`` or ``"correct"``. The name shadows
                the builtin but is kept because callers may pass it by
                keyword.

        Returns:
            Whatever the selected hook returns.

        Raises:
            ValueError: If ``type`` is not one of the three known modes.
                Previously this case fell through and returned ``None``.
        """
        if type == "train":
            return self.collate_train(dataloader_batch)
        if type == "test":
            return self.collate_test(dataloader_batch)
        if type == "correct":
            return self.collate_correct(dataloader_batch)
        raise ValueError(
            f"Unknown collate type {type!r}; "
            "expected 'train', 'test' or 'correct'"
        )

    def collate_train(self, dataloader_batch):
        """Hook for ``type == "train"``; must be overridden by subclasses."""
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement collate_train"
        )

    def collate_test(self, dataloader_batch):
        """Hook for ``type == "test"``; must be overridden by subclasses."""
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement collate_test"
        )

    def collate_correct(self, dataloader_batch):
        """Hook for ``type == "correct"``; must be overridden by subclasses."""
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement collate_correct"
        )
class DataCollatorForCharacterTransformer(PTCollator):
    """Collator that feeds batch samples through the token aligner.

    Samples are indexed positionally: ``sample[0]`` is collected as the
    label and ``sample[1]`` as the noised input. The constructor is
    inherited from ``PTCollator`` (it takes the token aligner).
    """

    @staticmethod
    def _split_samples(dataloader_batch):
        """Return ``(labels, noised)`` lists from one pass over the batch."""
        labels, noised = [], []
        for sample in dataloader_batch:
            labels.append(sample[0])
            noised.append(sample[1])
        return labels, noised

    def collate_train(self, dataloader_batch):
        """Build the training batch from noised inputs and their labels.

        Returns:
            dict with keys ``'batch_src'``, ``'batch_tgt'``, ``'attn_masks'``
            and ``'lengths'``, holding the four values returned by the
            token aligner.
        """
        labels, noised = self._split_samples(dataloader_batch)
        batch_srcs, batch_tgts, batch_lengths, batch_attention_masks = (
            self.tokenAligner.tokenize_for_transformer_with_tokenization(
                noised, labels
            )
        )
        return {
            'batch_src': batch_srcs,
            'batch_tgt': batch_tgts,
            'attn_masks': batch_attention_masks,
            'lengths': batch_lengths,
        }

    def collate_test(self, dataloader_batch):
        """Build the evaluation batch; labels are kept only as raw values.

        Returns:
            dict with keys ``'batch_src'``, ``'noised_texts'``,
            ``'label_texts'`` and ``'attn_masks'``.
        """
        labels, noised = self._split_samples(dataloader_batch)
        # NOTE(review): None is passed explicitly here while collate_correct
        # omits the second argument; presumably equivalent -- confirm against
        # TokenAligner.
        batch_srcs, batch_attention_masks = (
            self.tokenAligner.tokenize_for_transformer_with_tokenization(
                noised, None
            )
        )
        return {
            'batch_src': batch_srcs,
            'noised_texts': noised,
            'label_texts': labels,
            'attn_masks': batch_attention_masks,
        }

    def collate_correct(self, dataloader_batch):
        """Build the correction batch from the noised inputs only.

        Only ``sample[1]`` is read from each sample.

        Returns:
            dict with keys ``'batch_src'``, ``'noised_texts'`` and
            ``'attn_masks'``.
        """
        noised = [sample[1] for sample in dataloader_batch]
        batch_srcs, batch_attention_masks = (
            self.tokenAligner.tokenize_for_transformer_with_tokenization(noised)
        )
        return {
            'batch_src': batch_srcs,
            'noised_texts': noised,
            'attn_masks': batch_attention_masks,
        }