import torch
from typing import Dict, List

from torch import Tensor
from transformers import PreTrainedTokenizerFast


class TensorizeDataset:
    """Converts tokenized documents into tensorized instances for a model."""

    def __init__(
        self, tokenizer: PreTrainedTokenizerFast, remove_singletons: bool = False
    ) -> None:
        self.tokenizer = tokenizer
        self.remove_singletons = remove_singletons
        # Build tensors on CPU; callers can move them to another device later.
        self.device = torch.device("cpu")

    def tensorize_data(
        self, split_data: List[Dict], training: bool = False
    ) -> List[Dict]:
        """Tensorize every document in a dataset split."""
        tensorized_data = []
        for document in split_data:
            tensorized_data.append(
                self.tensorize_instance_independent(document, training=training)
            )
        return tensorized_data

    def process_segment(self, segment: List) -> List:
        """Wrap a segment of subtoken IDs in the tokenizer's boundary tokens."""
        if self.tokenizer.sep_token_id is None:
            # SentencePiece-style tokenizer: bracket the segment with BOS/EOS.
            return [self.tokenizer.bos_token_id] + segment + [self.tokenizer.eos_token_id]
        else:
            # WordPiece-style tokenizer: bracket the segment with CLS/SEP.
            return [self.tokenizer.cls_token_id] + segment + [self.tokenizer.sep_token_id]

    def tensorize_instance_independent(
        self, document: Dict, training: bool = False
    ) -> Dict:
        """Tensorize a single document, processing each segment independently."""
        segments: List[List[int]] = document["sentences"]
        clusters: List = document.get("clusters", [])
        ext_predicted_mentions: List = document.get("ext_predicted_mentions", [])
        sentence_map: List[int] = document["sentence_map"]
        subtoken_map: List[int] = document["subtoken_map"]
        representatives: List = document.get("representatives", [])
        representative_embs: List = document.get("representative_embs", [])

        # Each segment becomes a [1, len(segment) + 2] tensor of subtoken IDs,
        # with boundary tokens added by process_segment.
        tensorized_sent: List[Tensor] = [
            torch.unsqueeze(
                torch.tensor(self.process_segment(sent), device=self.device), dim=0
            )
            for sent in segments
        ]
        # Segment lengths before boundary tokens are added.
        sent_len_list = [len(sent) for sent in segments]

        output_dict = {
            "tensorized_sent": tensorized_sent,
            "sentences": segments,
            "sent_len_list": sent_len_list,
            "doc_key": document.get("doc_key"),
            "clusters": clusters,
            "ext_predicted_mentions": ext_predicted_mentions,
            "subtoken_map": subtoken_map,
            "sentence_map": torch.tensor(sentence_map, device=self.device),
            "representatives": representatives,
            "representative_embs": representative_embs,
        }
        # Pass along any other metadata fields present on the document.
        for key in document:
            if key not in output_dict:
                output_dict[key] = document[key]

        # Optionally drop singleton clusters (clusters with only one mention).
        if self.remove_singletons:
            output_dict["clusters"] = [
                cluster for cluster in output_dict["clusters"] if len(cluster) > 1
            ]

        return output_dict
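

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how TensorizeDataset might be driven. The tokenizer
# checkpoint ("bert-base-cased") and every field of the document below are
# assumptions chosen to match the keys this class reads; real documents would
# come from the project's own preprocessing pipeline.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tensorizer = TensorizeDataset(tokenizer, remove_singletons=True)

    # Hypothetical pre-tokenized document: two segments of subtoken IDs.
    doc = {
        "doc_key": "example_doc",
        "sentences": [
            tokenizer.convert_tokens_to_ids(["Alice", "met", "Bob", "."]),
            tokenizer.convert_tokens_to_ids(["She", "waved", "."]),
        ],
        "sentence_map": [0, 0, 0, 0, 1, 1, 1],
        "subtoken_map": [0, 1, 2, 3, 4, 5, 6],
        # One two-mention cluster and one singleton; the singleton is dropped
        # because remove_singletons=True.
        "clusters": [[[0, 0], [4, 4]], [[2, 2]]],
    }

    tensorized = tensorizer.tensorize_data([doc])
    first = tensorized[0]
    print(first["tensorized_sent"][0].shape)  # torch.Size([1, 6]): 4 IDs + CLS/SEP
    print(first["clusters"])  # [[[0, 0], [4, 4]]]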