import os

# Silence the tokenizers fork/parallelism warning; see
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from typing import List, Tuple, Optional, Union

import torch
from transformers import BertTokenizerFast

from ..utils import del_all


class Tokenizer:
    def __init__(
        self,
        tokenizer_path: Union[str, os.PathLike],
    ):
        """
        Legacy loading path, kept for reference:

            tokenizer: BertTokenizerFast = torch.load(
                tokenizer_path, map_location=device, mmap=True
            )
            # tokenizer.save_pretrained("asset/tokenizer", legacy_format=False)
        """
        tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(tokenizer_path)
        self._tokenizer = tokenizer

        self.len = len(tokenizer)  # vocabulary size, including added tokens
        # ids of the special tokens the model relies on downstream
        self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]")
        self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
        self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
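        # NOTE: convert_tokens_to_ids falls back to the [UNK] id for tokens
        # missing from the vocab, so these attributes silently degrade if the
        # tokenizer directory was saved without the special tokens above.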

    @torch.inference_mode()
    def encode(
        self,
        text: List[str],
        num_vq: int,
        prompt: Optional[torch.Tensor] = None,
        device="cpu",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        input_ids_lst = []
        attention_mask_lst = []
        max_input_ids_len = -1
        max_attention_mask_len = -1
        prompt_size = 0

        if prompt is not None:
            assert prompt.size(0) == num_vq, "prompt dim 0 must equal num_vq"
            prompt_size = prompt.size(1)

        # avoid random speaker embedding of tokenizer in the other dims
        for t in text:
            x = self._tokenizer.encode_plus(
                t, return_tensors="pt", add_special_tokens=False, padding=True
            )
            input_ids_lst.append(x["input_ids"].squeeze_(0))
            attention_mask_lst.append(x["attention_mask"].squeeze_(0))
            del_all(x)
            ids_sz = input_ids_lst[-1].size(0)
            if ids_sz > max_input_ids_len:
                max_input_ids_len = ids_sz
            attn_sz = attention_mask_lst[-1].size(0)
            if attn_sz > max_attention_mask_len:
                max_attention_mask_len = attn_sz

        if prompt is not None:
            # reserve room on the right for the prompt tokens
            max_input_ids_len += prompt_size
            max_attention_mask_len += prompt_size

        input_ids = torch.zeros(
            len(input_ids_lst),
            max_input_ids_len,
            device=device,
            dtype=input_ids_lst[0].dtype,
        )
        for i in range(len(input_ids_lst)):
            input_ids.narrow(0, i, 1).narrow(
                1,
                max_input_ids_len - prompt_size - input_ids_lst[i].size(0),
                input_ids_lst[i].size(0),
            ).copy_(
                input_ids_lst[i]
            )  # left padding: shorter rows keep zeros at the front
        del_all(input_ids_lst)

        attention_mask = torch.zeros(
            len(attention_mask_lst),
            max_attention_mask_len,
            device=device,
            dtype=attention_mask_lst[0].dtype,
        )
        for i in range(len(attention_mask_lst)):
            attn = attention_mask.narrow(0, i, 1)
            attn.narrow(
                1,
                max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0),
                attention_mask_lst[i].size(0),
            ).copy_(
                attention_mask_lst[i]
            )  # left padding, mirroring the input_ids layout
            if prompt_size > 0:
                # prompt positions are always attended to
                attn.narrow(
                    1,
                    max_attention_mask_len - prompt_size,
                    prompt_size,
                ).fill_(1)
        del_all(attention_mask_lst)
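
        # From here the final outputs are assembled. Illustrative shapes for
        # B texts with longest length L and prompt size P (T = L + P):
        #   new_input_ids  : (B, T, num_vq) -- ids replicated per codebook
        #   attention_mask : (B, T)         -- 1 over text and prompt positions
        #   text_mask      : (B, T) bool    -- True over text positions only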
        text_mask = attention_mask.bool()
        # replicate the ids across the num_vq codebook dimension
        new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
        del input_ids

        if prompt_size > 0:
            # the prompt is not text: clear those positions in text_mask and
            # copy the prompt codes into the trailing slots of new_input_ids
            text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0)
            prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1)
            new_input_ids.narrow(
                1,
                max_input_ids_len - prompt_size,
                prompt_size,
            ).copy_(prompt_t)
            del prompt_t

        return new_input_ids, attention_mask, text_mask

    @torch.inference_mode()
    def decode(
        self,
        sequences: Union[List[int], List[List[int]]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs,
    ):
        return self._tokenizer.batch_decode(
            sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
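
# Minimal usage sketch (illustrative, not executed here): assumes the package
# is importable and that "asset/tokenizer" holds a HF-format tokenizer saved
# via save_pretrained, as in the legacy note above.
#
#     tok = Tokenizer("asset/tokenizer")
#     ids, attn, mask = tok.encode(["Hello world.", "Hi."], num_vq=4)
#     ids.shape   # (2, T, 4); attn and mask are (2, T)
#     tok.decode(ids[..., 0], skip_special_tokens=True)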