Gregniuki committed on
Commit
473b4d2
1 Parent(s): 59d0e09

Delete model/utils.py

Files changed (1)
  1. model/utils.py +0 -185
model/utils.py DELETED
@@ -1,185 +0,0 @@
-from __future__ import annotations
-
-import os
-import random
-from collections import defaultdict
-from importlib.resources import files
-
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-import jieba
-from pypinyin import lazy_pinyin, Style
-
-
-# seed everything
-
-
-def seed_everything(seed=0):
-    random.seed(seed)
-    os.environ["PYTHONHASHSEED"] = str(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
-
-# helpers
-
-
-def exists(v):
-    return v is not None
-
-
-def default(v, d):
-    return v if exists(v) else d
-
-
-# tensor helpers
-
-
-def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
-    if not exists(length):
-        length = t.amax()
-
-    seq = torch.arange(length, device=t.device)
-    return seq[None, :] < t[:, None]
-
-
-def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
-    max_seq_len = seq_len.max().item()
-    seq = torch.arange(max_seq_len, device=start.device).long()
-    start_mask = seq[None, :] >= start[:, None]
-    end_mask = seq[None, :] < end[:, None]
-    return start_mask & end_mask
-
-
-def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
-    lengths = (frac_lengths * seq_len).long()
-    max_start = seq_len - lengths
-
-    rand = torch.rand_like(frac_lengths)
-    start = (max_start * rand).long().clamp(min=0)
-    end = start + lengths
-
-    return mask_from_start_end_indices(seq_len, start, end)
-
-
-def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
-    if not exists(mask):
-        return t.mean(dim=1)
-
-    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
-    num = t.sum(dim=1)
-    den = mask.float().sum(dim=1)
-
-    return num / den.clamp(min=1.0)
-
-
-# simple utf-8 tokenizer, since paper went character based
-def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
-    list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
-    text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
-    return text
-
-
-# char tokenizer, based on custom dataset's extracted .txt file
-def list_str_to_idx(
-    text: list[str] | list[list[str]],
-    vocab_char_map: dict[str, int],  # {char: idx}
-    padding_value=-1,
-) -> int["b nt"]:  # noqa: F722
-    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
-    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
-    return text
-
-
-# Get tokenizer
-
-
-def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
-    """
-    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
-                - "char" for char-wise tokenizer, need .txt vocab_file
-                - "byte" for utf-8 tokenizer
-                - "custom" if you're directly passing in a path to the vocab.txt you want to use
-    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
-                - if use "char", derived from unfiltered character & symbol counts of custom dataset
-                - if use "byte", set to 256 (unicode byte range)
-    """
-    if tokenizer in ["pinyin", "char"]:
-        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
-        with open(tokenizer_path, "r", encoding="utf-8") as f:
-            vocab_char_map = {}
-            for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
-        vocab_size = len(vocab_char_map)
-        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
-
-    elif tokenizer == "byte":
-        vocab_char_map = None
-        vocab_size = 256
-
-    elif tokenizer == "custom":
-        with open(dataset_name, "r", encoding="utf-8") as f:
-            vocab_char_map = {}
-            for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
-        vocab_size = len(vocab_char_map)
-
-    return vocab_char_map, vocab_size
-
-
-# convert char to pinyin
-
-
-def convert_char_to_pinyin(text_list, polyphone=True):
-    final_text_list = []
-    god_knows_why_en_testset_contains_zh_quote = str.maketrans(
-        {"“": '"', "”": '"', "‘": "'", "’": "'"}
-    )  # in case librispeech (orig no-pc) test-clean
-    custom_trans = str.maketrans({";": ","})  # add custom trans here, to address oov
-    for text in text_list:
-        char_list = []
-        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
-        text = text.translate(custom_trans)
-        for seg in jieba.cut(text):
-            seg_byte_len = len(bytes(seg, "UTF-8"))
-            if seg_byte_len == len(seg):  # if pure alphabets and symbols
-                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
-                    char_list.append(" ")
-                char_list.extend(seg)
-            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure chinese characters
-                seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
-                for c in seg:
-                    if c not in "。,、;:?!《》【】—…":
-                        char_list.append(" ")
-                    char_list.append(c)
-            else:  # if mixed chinese characters, alphabets and symbols
-                for c in seg:
-                    if ord(c) < 256:
-                        char_list.extend(c)
-                    else:
-                        if c not in "。,、;:?!《》【】—…":
-                            char_list.append(" ")
-                            char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
-                        else:  # if is zh punc
-                            char_list.append(c)
-        final_text_list.append(char_list)
-
-    return final_text_list
-
-
-# filter func for dirty data with many repetitions
-
-
-def repetition_found(text, length=2, tolerance=10):
-    pattern_count = defaultdict(int)
-    for i in range(len(text) - length + 1):
-        pattern = text[i : i + length]
-        pattern_count[pattern] += 1
-    for pattern, count in pattern_count.items():
-        if count > tolerance:
-            return True
-    return False
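
For reference, a minimal usage sketch of the helpers removed by this commit. It is illustrative only: the import path `model.utils`, the placeholder vocab path, and the sample sentence are assumptions, and the `"custom"` tokenizer mode is used because it reads the given vocab file directly instead of resolving a dataset directory under `data/`.

# Illustrative sketch only; the import path and file paths below are assumptions.
import torch

from model.utils import (
    convert_char_to_pinyin,
    get_tokenizer,
    lens_to_mask,
    list_str_to_idx,
    seed_everything,
)

seed_everything(0)

# "custom" mode treats the first argument as a direct path to a vocab.txt
vocab_char_map, vocab_size = get_tokenizer("path/to/vocab.txt", tokenizer="custom")

# mixed Chinese/English text -> per-sample char/pinyin lists -> padded index tensor
texts = ["你好, this is a test."]
pinyin_lists = convert_char_to_pinyin(texts, polyphone=True)
text_ids = list_str_to_idx(pinyin_lists, vocab_char_map, padding_value=-1)  # (b, nt)

# boolean padding mask from per-sample token counts
lengths = torch.tensor([len(seq) for seq in pinyin_lists])
mask = lens_to_mask(lengths)  # (b, n), True at valid positions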