import sys

# Make the parent directory importable so that `dataset.vocab` resolves.
sys.path.append("..")

from dataset.vocab import Vocab
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer


class TokenAligner:
    def __init__(self, tokenizer: AutoTokenizer, vocab: Vocab):
        self.tokenizer = tokenizer
        self.vocab = vocab

    def _char_tokenize(self, text: str):
        """Split `text` into character-level tokens and encode them.

        Every character that is not word-final is suffixed with "@@"
        (e.g. "ab cd" -> ["a@@", "b", "c@@", "d"]); space characters are
        dropped, so word boundaries are only marked by the missing "@@".
        """
        characters = list(text)
        # Suffix a character with "@@" unless it is followed by a space or is
        # the last character of the text.
        tokens = [
            char + "@@" if i < len(characters) - 1 and characters[i + 1] != " " else char
            for i, char in enumerate(characters)
        ]
        # Drop the tokens produced by the space characters themselves.
        tokens = [token for token in tokens if token not in [" @@", " "]]
        # Encode the pre-split tokens; each token is expected to exist in the
        # tokenizer's vocabulary.
        encoded = self.tokenizer.encode_plus(tokens, return_tensors="pt")
        token_ids = encoded['input_ids'].squeeze(0)
        attn_mask = encoded['attention_mask'].squeeze(0)
        return tokens, token_ids, attn_mask
    
    def char_tokenize(self, batch_texts):
        """Character-tokenize a batch of texts and collect the per-text results."""
        doc = {'tokens': [], 'token_ids': [], 'attention_mask': []}
        for text in batch_texts:
            tokens, token_ids, attn_mask = self._char_tokenize(text)
            doc['tokens'].append(tokens)
            doc['token_ids'].append(token_ids)
            doc['attention_mask'].append(attn_mask)
        return doc

    def tokenize_for_transformer_with_tokenization(self, batch_noised_text, batch_label_texts=None):
        """Build padded source tensors from the character-tokenized noisy texts.

        When `batch_label_texts` is given, the labels are encoded with the
        tokenizer's regular (subword) tokenization and returned as padded
        target tensors together with their subword lengths.
        """
        docs = self.char_tokenize(batch_noised_text)
        batch_srcs = docs['token_ids']
        batch_attention_masks = docs['attention_mask']

        # Pad the per-example tensors to the length of the longest example.
        batch_attention_masks = pad_sequence(batch_attention_masks,
            batch_first=True, padding_value=0)
        batch_srcs = pad_sequence(batch_srcs,
            batch_first=True, padding_value=self.tokenizer.pad_token_id)

        if batch_label_texts is not None:
            # Number of subword tokens in each label text (without special tokens).
            batch_lengths = [len(self.tokenizer.tokenize(text)) for text in batch_label_texts]
            batch_tgts = self.tokenizer.batch_encode_plus(
                batch_label_texts, max_length=512, truncation=True,
                padding=True, return_tensors="pt")['input_ids']
            return batch_srcs, batch_tgts, batch_lengths, batch_attention_masks

        return batch_srcs, batch_attention_masks
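

# Minimal usage sketch (not part of the original module). The checkpoint path and
# the Vocab constructor argument below are placeholders; the real values depend on
# how the surrounding project builds its character-level tokenizer and vocabulary.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("path/to/char-level-tokenizer")  # placeholder path
    vocab = Vocab("path/to/vocab.txt")  # hypothetical argument; see dataset/vocab.py
    aligner = TokenAligner(tokenizer, vocab)

    noised = ["helo wrld", "speling eror"]
    labels = ["hello world", "spelling error"]

    srcs, tgts, lengths, attn = aligner.tokenize_for_transformer_with_tokenization(noised, labels)
    print(srcs.shape, tgts.shape, attn.shape, lengths)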