File size: 2,853 Bytes
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from typing import List, Dict, Union
from transformers import PreTrainedTokenizerFast
from torch import Tensor


class TensorizeDataset:
    """Turn pre-tokenized documents into tensors plus pass-through metadata.

    Each input document is a dict that must contain ``"sentences"`` (a list of
    token-id segments), ``"sentence_map"`` and ``"subtoken_map"``. The optional
    keys ``"clusters"``, ``"ext_predicted_mentions"``, ``"representatives"``,
    ``"representative_embs"`` and ``"doc_key"`` get defaults when absent. Every
    other key in the document is copied to the output unchanged.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerFast",
        remove_singletons: bool = False,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """
        Args:
            tokenizer: Source of the special-token ids (``sep``/``cls`` or
                ``bos``/``eos``) wrapped around each segment.
            remove_singletons: If True, drop clusters whose length is <= 1
                from the output.
            device: Device on which output tensors are created. Defaults to
                CPU, which matches the previous hard-coded behavior.
        """
        self.tokenizer = tokenizer
        self.remove_singletons = remove_singletons
        self.device = torch.device(device)

    def tensorize_data(
        self, split_data: List[Dict], training: bool = False
    ) -> List[Dict]:
        """Tensorize every document in ``split_data``, preserving order.

        ``training`` is forwarded to :meth:`tensorize_instance_independent`.
        """
        return [
            self.tensorize_instance_independent(document, training=training)
            for document in split_data
        ]

    def process_segment(self, segment: List) -> List:
        """Return ``segment`` wrapped in the tokenizer's start/end token ids.

        Tokenizers without a separator token (``sep_token_id is None``) are
        wrapped as ``[bos] + segment + [eos]``. All others are wrapped as
        ``[cls] + segment + [sep]``.
        """
        tok = self.tokenizer
        if tok.sep_token_id is None:
            # No SEP token: presumably a SentencePiece-style tokenizer.
            return [tok.bos_token_id] + segment + [tok.eos_token_id]
        # SEP token available: presumably a WordPiece-style tokenizer.
        return [tok.cls_token_id] + segment + [tok.sep_token_id]

    def tensorize_instance_independent(
        self, document: Dict, training: bool = False
    ) -> Dict:
        """Tensorize one document.

        Args:
            document: Input dict; see the class docstring for expected keys.
            training: Currently unused. Accepted for interface compatibility.

        Returns:
            A new dict holding:

            * ``tensorized_sent``: one tensor per segment, shaped
              ``(1, len(segment) + 2)`` to make room for the two special tokens.
            * ``sent_len_list``: the raw segment lengths, excluding special tokens.
            * ``sentence_map``: a ``torch.long`` tensor.
            * The remaining fields and any extra document keys, copied as-is.

            The input ``document`` is not mutated.

        Raises:
            KeyError: If ``"sentences"``, ``"sentence_map"`` or
                ``"subtoken_map"`` is missing.
        """
        segments: List[List[int]] = document["sentences"]
        clusters: List = document.get("clusters", [])
        ext_predicted_mentions: List = document.get("ext_predicted_mentions", [])
        sentence_map: List[int] = document["sentence_map"]
        subtoken_map: List[int] = document["subtoken_map"]
        representatives: List = document.get("representatives", [])
        representative_embs: List = document.get("representative_embs", [])

        tensorized_sent: List[Tensor] = [
            torch.tensor(self.process_segment(sent), device=self.device).unsqueeze(0)
            for sent in segments
        ]

        output_dict = {
            "tensorized_sent": tensorized_sent,
            "sentences": segments,
            "sent_len_list": [len(sent) for sent in segments],
            "doc_key": document.get("doc_key", None),
            "clusters": clusters,
            "ext_predicted_mentions": ext_predicted_mentions,
            "subtoken_map": subtoken_map,
            # Explicit dtype: an empty list would otherwise become float32,
            # while a non-empty int list becomes int64.
            "sentence_map": torch.tensor(
                sentence_map, dtype=torch.long, device=self.device
            ),
            "representatives": representatives,
            "representative_embs": representative_embs,
        }

        # Pass along other metadata without overriding the computed fields.
        for key, value in document.items():
            output_dict.setdefault(key, value)

        if self.remove_singletons:
            # Build a new list so the caller's document is left untouched.
            output_dict["clusters"] = [
                cluster for cluster in output_dict["clusters"] if len(cluster) > 1
            ]

        return output_dict