import os

# Silence the huggingface/tokenizers fork-parallelism warning; see
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from typing import List, Tuple, Optional, Union

import torch
from transformers import BertTokenizerFast

from ..utils import del_all


class Tokenizer:
    def __init__(
        self,
        tokenizer_path: str,
    ):
        """
        tokenizer: BertTokenizerFast = torch.load(
            tokenizer_path, map_location=device, mmap=True
        )
        # tokenizer.save_pretrained("asset/tokenizer", legacy_format=False)
        """
        tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(tokenizer_path)
        self._tokenizer = tokenizer

        self.len = len(tokenizer)
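        # IDs of the control tokens used by the model: [spk_emb] marks where
        # the speaker embedding is injected, [break_0] is the shortest
        # pause/break marker, and [Ebreak] serves as the end-of-sequence token.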
        self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]")
        self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
        self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")

    @torch.inference_mode()
    def encode(
        self,
        text: List[str],
        num_vq: int,
        prompt: Optional[torch.Tensor] = None,
        device="cpu",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        input_ids_lst = []
        attention_mask_lst = []
        max_input_ids_len = -1
        max_attention_mask_len = -1
        prompt_size = 0

        if prompt is not None:
            assert prompt.size(0) == num_vq, "prompt dim 0 must equal num_vq"
            prompt_size = prompt.size(1)

        # avoid random speaker embedding of tokenizer in the other dims
        for t in text:
            x = self._tokenizer.encode_plus(
                t, return_tensors="pt", add_special_tokens=False, padding=True
            )
            input_ids_lst.append(x["input_ids"].squeeze_(0))
            attention_mask_lst.append(x["attention_mask"].squeeze_(0))
            del_all(x)
            ids_sz = input_ids_lst[-1].size(0)
            if ids_sz > max_input_ids_len:
                max_input_ids_len = ids_sz
            attn_sz = attention_mask_lst[-1].size(0)
            if attn_sz > max_attention_mask_len:
                max_attention_mask_len = attn_sz

        if prompt is not None:
            max_input_ids_len += prompt_size
            max_attention_mask_len += prompt_size

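        # Left padding: each sequence is copied into the rightmost slots of
        # its row (see the narrow() offsets below), with the final prompt_size
        # columns reserved for the prompt tokens so generation continues from
        # the right edge of the batch.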
        input_ids = torch.zeros(
            len(input_ids_lst),
            max_input_ids_len,
            device=device,
            dtype=input_ids_lst[0].dtype,
        )
        for i in range(len(input_ids_lst)):
            input_ids.narrow(0, i, 1).narrow(
                1,
                max_input_ids_len - prompt_size - input_ids_lst[i].size(0),
                input_ids_lst[i].size(0),
            ).copy_(
                input_ids_lst[i]
            )  # left padding
        del_all(input_ids_lst)

        attention_mask = torch.zeros(
            len(attention_mask_lst),
            max_attention_mask_len,
            device=device,
            dtype=attention_mask_lst[0].dtype,
        )
        for i in range(len(attention_mask_lst)):
            attn = attention_mask.narrow(0, i, 1)
            attn.narrow(
                1,
                max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0),
                attention_mask_lst[i].size(0),
            ).copy_(
                attention_mask_lst[i]
            )  # left padding
            if prompt_size > 0:
                attn.narrow(
                    1,
                    max_attention_mask_len - prompt_size,
                    prompt_size,
                ).fill_(1)
        del_all(attention_mask_lst)

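        # Replicate the flat text ids across all num_vq code channels; the
        # per-channel prompt codes, if any, are copied into the reserved
        # right-hand slots below.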
        text_mask = attention_mask.bool()
        new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
        del input_ids

        if prompt_size > 0:
            text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0)
            prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1)
            new_input_ids.narrow(
                1,
                max_input_ids_len - prompt_size,
                prompt_size,
            ).copy_(prompt_t)
            del prompt_t

        return new_input_ids, attention_mask, text_mask

    @torch.inference_mode()
    def decode(
        self,
        sequences: Union[List[int], List[List[int]]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs,
    ):
        return self._tokenizer.batch_decode(
            sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
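

# A minimal usage sketch, not part of the library. It assumes a tokenizer
# saved at "asset/tokenizer" (the path from the commented save_pretrained
# call above) and num_vq=4; both are assumptions, adjust them for your setup.
# Because of the relative import above, run this in package context, e.g.
# `python -m <package>.tokenizer`, not as a standalone script.
if __name__ == "__main__":
    tokenizer = Tokenizer("asset/tokenizer")
    ids, attention_mask, text_mask = tokenizer.encode(
        ["hello world", "a longer sentence that exercises left padding"],
        num_vq=4,
    )
    # ids: (batch, seq_len, num_vq); masks: (batch, seq_len)
    print(ids.shape, attention_mask.shape, text_mask.shape)
    # Decode channel 0 back to text to verify the round trip.
    print(tokenizer.decode(ids[..., 0].tolist(), skip_special_tokens=True))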