vuxuanhoan commited on
Commit
7e40419
1 Parent(s): a9a13dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -222
app.py CHANGED
@@ -1,231 +1,62 @@
1
- import torch # isort:skip
2
-
3
- torch.manual_seed(42)
4
  import json
5
- import re
6
- import unicodedata
7
- from types import SimpleNamespace
8
-
9
  import gradio as gr
10
- import numpy as np
11
- import regex
12
-
13
- from models import DurationNet, SynthesizerTrn
14
-
15
- title = "LightSpeed: Vietnamese Male Voice TTS"
16
- description = "Vietnam Male Voice TTS."
17
- config_file = "config.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  duration_model_path = "vbx_duration_model.pth"
19
- lightspeed_model_path = "gen_619k.pth"
20
- phone_set_file = "vbx_phone_set.json"
21
- device = "cpu"
22
- #device = "cuda" if torch.cuda.is_available() else "cpu"
23
- with open(config_file, "rb") as f:
24
- hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))
25
-
26
- # load phone set json file
27
- with open(phone_set_file, "r") as f:
28
- phone_set = json.load(f)
29
-
30
- assert phone_set[0][1:-1] == "SEP"
31
- assert "sil" in phone_set
32
- sil_idx = phone_set.index("sil")
33
-
34
- space_re = regex.compile(r"\s+")
35
- number_re = regex.compile("([0-9]+)")
36
- digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
37
- num_re = regex.compile(r"([0-9.,]*[0-9])")
38
- alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
39
- keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
40
- keep_text_re = regex.compile(rf"[^\s{alphabet}]")
41
-
42
-
43
- def read_number(num: str) -> str:
44
- if len(num) == 1:
45
- return digits[int(num)]
46
- elif len(num) == 2 and num.isdigit():
47
- n = int(num)
48
- end = digits[n % 10]
49
- if n == 10:
50
- return "mười"
51
- if n % 10 == 5:
52
- end = "lăm"
53
- if n % 10 == 0:
54
- return digits[n // 10] + " mươi"
55
- elif n < 20:
56
- return "mười " + end
57
- else:
58
- if n % 10 == 1:
59
- end = "mốt"
60
- return digits[n // 10] + " mươi " + end
61
- elif len(num) == 3 and num.isdigit():
62
- n = int(num)
63
- if n % 100 == 0:
64
- return digits[n // 100] + " trăm"
65
- elif num[1] == "0":
66
- return digits[n // 100] + " trăm lẻ " + digits[n % 100]
67
- else:
68
- return digits[n // 100] + " trăm " + read_number(num[1:])
69
- elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
70
- n = int(num)
71
- n1 = n // 1000
72
- return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
73
- elif "," in num:
74
- n1, n2 = num.split(",")
75
- return read_number(n1) + " phẩy " + read_number(n2)
76
- elif "." in num:
77
- parts = num.split(".")
78
- if len(parts) == 2:
79
- if parts[1] == "000":
80
- return read_number(parts[0]) + " ngàn"
81
- elif parts[1].startswith("00"):
82
- end = digits[int(parts[1][2:])]
83
- return read_number(parts[0]) + " ngàn lẻ " + end
84
- else:
85
- return read_number(parts[0]) + " ngàn " + read_number(parts[1])
86
- elif len(parts) == 3:
87
- return (
88
- read_number(parts[0])
89
- + " triệu "
90
- + read_number(parts[1])
91
- + " ngàn "
92
- + read_number(parts[2])
93
- )
94
- return num
95
-
96
-
97
- def text_to_phone_idx(text):
98
- # lowercase
99
- text = text.lower()
100
- # unicode normalize
101
- text = unicodedata.normalize("NFKC", text)
102
- text = text.replace(".", " . ")
103
- text = text.replace(",", " , ")
104
- text = text.replace(";", " ; ")
105
- text = text.replace(":", " : ")
106
- text = text.replace("!", " ! ")
107
- text = text.replace("?", " ? ")
108
- text = text.replace("(", " ( ")
109
-
110
- text = num_re.sub(r" \1 ", text)
111
- words = text.split()
112
- words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
113
- text = " ".join(words)
114
-
115
- # remove redundant spaces
116
- text = re.sub(r"\s+", " ", text)
117
- # remove leading and trailing spaces
118
- text = text.strip()
119
- # convert words to phone indices
120
- tokens = []
121
- for c in text:
122
- # if c is "," or ".", add <sil> phone
123
- if c in ":,.!?;(":
124
- tokens.append(sil_idx)
125
- elif c in phone_set:
126
- tokens.append(phone_set.index(c))
127
- elif c == " ":
128
- # add <sep> phone
129
- tokens.append(0)
130
- if tokens[0] != sil_idx:
131
- # insert <sil> phone at the beginning
132
- tokens = [sil_idx, 0] + tokens
133
- if tokens[-1] != sil_idx:
134
- tokens = tokens + [0, sil_idx]
135
- return tokens
136
-
137
-
138
- def text_to_speech(duration_net, generator, text):
139
- # prevent too long text
140
- if len(text) > 500:
141
- text = text[:500]
142
-
143
- phone_idx = text_to_phone_idx(text)
144
- batch = {
145
- "phone_idx": np.array([phone_idx]),
146
- "phone_length": np.array([len(phone_idx)]),
147
- }
148
-
149
- # predict phoneme duration
150
- phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
151
- phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
152
- with torch.inference_mode():
153
- phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
154
- phone_duration = torch.where(
155
- phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
156
- )
157
- phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
158
-
159
- # generate waveform
160
- end_time = torch.cumsum(phone_duration, dim=-1)
161
- start_time = end_time - phone_duration
162
- start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
163
- end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
164
- spec_length = end_frame.max(dim=-1).values
165
- pos = torch.arange(0, spec_length.item(), device=device)
166
- attn = torch.logical_and(
167
- pos[None, :, None] >= start_frame[:, None, :],
168
- pos[None, :, None] < end_frame[:, None, :],
169
- ).float()
170
- with torch.inference_mode():
171
- y_hat = generator.infer(
172
- phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.667
173
- )[0]
174
- wave = y_hat[0, 0].data.cpu().numpy()
175
- return (wave * (2**15)).astype(np.int16)
176
 
 
 
 
177
 
178
- def load_models():
179
- duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
180
- duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
181
- duration_net = duration_net.eval()
182
- generator = SynthesizerTrn(
183
- hps.data.vocab_size,
184
- hps.data.filter_length // 2 + 1,
185
- hps.train.segment_size // hps.data.hop_length,
186
- **vars(hps.model),
187
- ).to(device)
188
- del generator.enc_q
189
- ckpt = torch.load(lightspeed_model_path, map_location=device)
190
- params = {}
191
- for k, v in ckpt["net_g"].items():
192
- k = k[7:] if k.startswith("module.") else k
193
- params[k] = v
194
- generator.load_state_dict(params, strict=False)
195
- del ckpt, params
196
- generator = generator.eval()
197
- return duration_net, generator
198
 
 
 
 
199
 
200
- def speak(text):
201
- duration_net, generator = load_models()
202
- paragraphs = text.split("\n")
203
- clips = [] # list of audio clips
204
- # silence = np.zeros(hps.data.sampling_rate // 4)
205
- for paragraph in paragraphs:
206
- paragraph = paragraph.strip()
207
- if paragraph == "":
208
- continue
209
- clips.append(text_to_speech(duration_net, generator, paragraph))
210
- # clips.append(silence)
211
- y = np.concatenate(clips)
212
- return hps.data.sampling_rate, y
213
 
 
 
214
 
215
- gr.Interface(
216
- fn=speak,
217
- inputs="text",
218
- outputs="audio",
219
- title=title,
220
- examples=[
221
- "Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
222
- "Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du",
223
- "Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
224
- "Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn của Việt Nam trong thời phong kiến",
225
- "Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được; trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
226
- ],
227
- description=description,
228
- theme="default",
229
- allow_screenshot=False,
230
- allow_flagging="never",
231
- ).launch(debug=False)
 
 
 
 
1
  import json
2
+ import torch
 
 
 
3
  import gradio as gr
4
+ from torch import nn
5
+ import soundfile as sf
6
+
7
+ # Định nghĩa lớp mô hình TTS
8
+ class TTSModel(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ # Khởi tạo các lớp của mô hình TTS ở đây
12
+ # Giả định bạn đã xác định cách xây dựng mô hình dựa trên cấu hình
13
+ self.config = config
14
+ # Các thành phần khác của mô hình sẽ được thêm vào đây.
15
+
16
+ def forward(self, inputs):
17
+ # Logic của mô hình để chuyển đổi văn bản thành giọng nói
18
+ # Giả định rằng bạn đã thực hiện điều này
19
+ return torch.zeros(22050) # Trả về một tensor âm thanh giả định
20
+
21
+ # Hàm tiền xử lý văn bản
22
+ def preprocess_text(text):
23
+ # Chuyển đổi văn bản thành dạng số (encoding)
24
+ return text # Đây chỉ là một ví dụ đơn giản
25
+
26
+ # Hàm chuyển đổi văn bản thành giọng nói
27
+ def text_to_speech(text):
28
+ inputs = preprocess_text(text)
29
+
30
+ with torch.no_grad():
31
+ audio_output = model(inputs)
32
+
33
+ # Giả sử rằng audio_output là một tensor âm thanh
34
+ return audio_output.numpy()
35
+
36
+ # Tải cấu hình và trọng số mô hình
37
+ config_path = "config.json"
38
  duration_model_path = "vbx_duration_model.pth"
39
+ generation_model_path = "gen_619k.pth"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # Tải cấu hình
42
+ with open(config_path, 'r') as f:
43
+ config = json.load(f)
44
 
45
+ # Tạo mô hình
46
+ model = TTSModel(config)
47
+ model.eval() # Chuyển mô hình về chế độ đánh giá
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ # Tải trọng số mô hình
50
+ model.load_state_dict(torch.load(duration_model_path, map_location=torch.device('cpu')))
51
+ model.load_state_dict(torch.load(generation_model_path, map_location=torch.device('cpu')))
52
 
53
+ # Xây dựng giao diện Gradio
54
+ def infer(text):
55
+ audio = text_to_speech(text)
56
+ sf.write('output.wav', audio, 22050) # Lưu âm thanh vào tệp WAV
57
+ return 'output.wav'
 
 
 
 
 
 
 
 
58
 
59
+ iface = gr.Interface(fn=infer, inputs="text", outputs="audio", title="Text to Speech",
60
+ description="Chuyển đổi văn bản tiếng Việt thành giọng nói.")
61
 
62
+ iface.launch()