import torch
from typing import Dict, List

from torch import Tensor
from transformers import PreTrainedTokenizerFast


class TensorizeDataset:
    """Converts preprocessed documents (lists of subtoken IDs plus metadata)
    into tensorized instances ready for model consumption."""

    def __init__(
        self, tokenizer: PreTrainedTokenizerFast, remove_singletons: bool = False
    ) -> None:
        self.tokenizer = tokenizer
        self.remove_singletons = remove_singletons
        self.device = torch.device("cpu")

    def tensorize_data(
        self, split_data: List[Dict], training: bool = False
    ) -> List[Dict]:
        """Tensorize every document in a data split."""
        tensorized_data = []
        for document in split_data:
            tensorized_data.append(
                self.tensorize_instance_independent(document, training=training)
            )
        return tensorized_data

    def process_segment(self, segment: List) -> List:
        """Wrap a segment of subtoken IDs with the tokenizer's special tokens."""
        if self.tokenizer.sep_token_id is None:
            # SentencePiece-style tokenizer: no [SEP] token, so use BOS/EOS.
            return (
                [self.tokenizer.bos_token_id] + segment + [self.tokenizer.eos_token_id]
            )
        else:
            # WordPiece-style tokenizer: use [CLS]/[SEP].
            return (
                [self.tokenizer.cls_token_id] + segment + [self.tokenizer.sep_token_id]
            )

    def tensorize_instance_independent(
        self, document: Dict, training: bool = False
    ) -> Dict:
        """Tensorize a single document. The ``training`` flag is accepted for
        API symmetry with ``tensorize_data`` but is currently unused here."""
        segments: List[List[int]] = document["sentences"]
        clusters: List = document.get("clusters", [])
        ext_predicted_mentions: List = document.get("ext_predicted_mentions", [])
        sentence_map: List[int] = document["sentence_map"]
        subtoken_map: List[int] = document["subtoken_map"]
        representatives: List = document.get("representatives", [])
        representative_embs: List = document.get("representative_embs", [])

        # Each segment becomes a (1, seg_len + 2) tensor after special tokens are added.
        tensorized_sent: List[Tensor] = [
            torch.unsqueeze(
                torch.tensor(self.process_segment(sent), device=self.device), dim=0
            )
            for sent in segments
        ]
        # Segment lengths *before* special tokens are added.
        sent_len_list = [len(sent) for sent in segments]

        output_dict = {
            "tensorized_sent": tensorized_sent,
            "sentences": segments,
            "sent_len_list": sent_len_list,
            "doc_key": document.get("doc_key", None),
            "clusters": clusters,
            "ext_predicted_mentions": ext_predicted_mentions,
            "subtoken_map": subtoken_map,
            "sentence_map": torch.tensor(sentence_map, device=self.device),
            "representatives": representatives,
            "representative_embs": representative_embs,
        }
        # Pass along any other metadata untouched.
        for key in document:
            if key not in output_dict:
                output_dict[key] = document[key]

        # Optionally drop singleton clusters (clusters with a single mention).
        if self.remove_singletons:
            output_dict["clusters"] = [
                cluster for cluster in output_dict["clusters"] if len(cluster) > 1
            ]

        return output_dict
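
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving TensorizeDataset. The checkpoint name
# ("bert-base-uncased") and the toy document below are assumptions for
# illustration only; real documents come from the preprocessing pipeline
# and carry subtoken IDs produced by the same tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    dataset = TensorizeDataset(tokenizer, remove_singletons=True)

    # Toy document with two single-sentence segments; the field layout mirrors
    # what tensorize_instance_independent reads above.
    doc = {
        "doc_key": "toy_0",
        "sentences": [
            tokenizer.encode("Alice met Bob .", add_special_tokens=False),
            tokenizer.encode("She smiled .", add_special_tokens=False),
        ],
        "sentence_map": [0, 0, 0, 0, 1, 1, 1],
        "subtoken_map": [0, 1, 2, 3, 4, 5, 6],
        "clusters": [[[0, 0], [4, 4]], [[2, 2]]],  # second cluster is a singleton
    }

    tensorized = dataset.tensorize_data([doc], training=True)
    print(tensorized[0]["sent_len_list"])  # segment lengths, e.g. [4, 3]
    print(tensorized[0]["clusters"])       # singleton dropped -> [[[0, 0], [4, 4]]]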