le quy don committed on
Commit 8998055 · verified · 1 Parent(s): 1159cf1

Update app.py

Files changed (1)
  1. app.py +285 -225
app.py CHANGED
@@ -1,258 +1,318 @@
- import spaces
- import torch
  import os
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
- import gradio as gr
- import traceback
  import gc
  import numpy as np
  import librosa
  from pydub import AudioSegment
  from pydub.effects import normalize
  from huggingface_hub import snapshot_download
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav

- # Set basic CPU optimization flags
- os.environ["OMP_NUM_THREADS"] = str(os.cpu_count())
- torch.set_num_threads(os.cpu_count())

- def download_weights():
-     """Download model weights from HuggingFace if not already present."""
-     repo_id = "mrfakename/MegaTTS3-VoiceCloning"
-     weights_dir = "checkpoints"
-
-     if not os.path.exists(weights_dir):
-         print("Downloading model weights from HuggingFace...")
-         snapshot_download(
-             repo_id=repo_id,
-             local_dir=weights_dir,
-             local_dir_use_symlinks=False,
-             resume_download=True
-         )
-         print("Model weights downloaded successfully!")
-     else:
-         print("Model weights already exist.")
-
-     return weights_dir

- # Download weights and initialize model
- download_weights()
- print("Initializing MegaTTS3 model...")
- # Force model to use CPU
- infer_pipe = MegaTTS3DiTInfer(device="cpu")
- print(f"Model loaded successfully on CPU with {os.cpu_count()} threads!")

- def reset_model():
-     """Reset the inference pipeline"""
-     global infer_pipe
-     try:
-         print("Reinitializing MegaTTS3 model...")
-         infer_pipe = MegaTTS3DiTInfer(device="cpu")
-         print("Model reinitialized successfully on CPU!")
-         return True
-     except Exception as e:
-         print(f"Failed to reinitialize model: {e}")
-         return False

- def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w, speed_factor):
-     if not inp_audio or not inp_text:
-         gr.Warning("Please provide both reference audio and text to generate.")
-         return None
-
-     try:
-         print(f"Generating speech with: {inp_text}...")
-         print(f"Running on CPU with {os.cpu_count()} threads...")

-         # Robustly preprocess audio
          try:
-             processed_audio_path = preprocess_audio_robust(inp_audio)
-             # Use existing cut_wav for final trimming
-             cut_wav(processed_audio_path, max_len=28)
-             wav_path = processed_audio_path
-         except Exception as audio_error:
-             gr.Warning(f"Audio preprocessing failed: {str(audio_error)}")
              return None
-
-         # Read audio file
-         with open(wav_path, 'rb') as file:
-             file_content = file.read()
-
-         # Generate speech with proper error handling
          try:
-             with torch.no_grad(): # Use no_grad for inference
-                 resource_context = infer_pipe.preprocess(file_content)
-                 wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)

-             # Apply speed adjustment if needed
              if speed_factor != 1.0:
-                 wav_bytes = adjust_speed(wav_bytes, speed_factor)

-             # Clean up memory after successful generation
-             cleanup_memory()
              return wav_bytes
-         except RuntimeError as e:
-             print(f"Error during inference: {e}")
-             # Try to reset the model
-             if reset_model():
-                 gr.Warning("Error occurred. Model has been reset. Please try again.")
-             else:
-                 gr.Warning("Error occurred and model reset failed. Please restart the application.")
-             return None
-
-     except Exception as e:
-         traceback.print_exc()
-         gr.Warning(f"Speech generation failed: {str(e)}")
-         # Clean up memory on any error
-         cleanup_memory()
-         return None

- def adjust_speed(wav_bytes, speed_factor):
-     """Adjust the speed of the audio without changing pitch"""
-     try:
-         # Create temp file
-         temp_input = "temp_input.wav"
-         temp_output = "temp_output.wav"
-
-         with open(temp_input, "wb") as f:
-             f.write(wav_bytes)
-
-         # Load audio
-         audio = AudioSegment.from_file(temp_input)
-
-         # Apply speed change
-         if speed_factor != 1.0:
-             # Manually adjust frame rate to change speed without pitch alteration
-             new_frame_rate = int(audio.frame_rate * speed_factor)
-             audio = audio._spawn(audio.raw_data, overrides={
-                 "frame_rate": new_frame_rate
-             }).set_frame_rate(audio.frame_rate)
-
-         # Export result
-         audio.export(temp_output, format="wav")
-
-         # Read and return
-         with open(temp_output, "rb") as f:
-             result = f.read()
-
-         # Clean up temp files
-         os.remove(temp_input)
-         os.remove(temp_output)

-         return result
-     except Exception as e:
-         print(f"Speed adjustment failed: {e}")
-         return wav_bytes # Return original if adjustment fails

- def cleanup_memory():
-     """Clean up system memory."""
-     gc.collect()
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()

- def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
-     """Robustly preprocess audio"""
-     try:
-         # Load with pydub for robust format handling
-         audio = AudioSegment.from_file(audio_path)
-
-         # Convert to mono if stereo
-         if audio.channels > 1:
-             audio = audio.set_channels(1)
-
-         # Limit duration to prevent memory issues
-         if len(audio) > max_duration * 1000: # pydub uses milliseconds
-             audio = audio[:max_duration * 1000]
-
-         # Normalize audio to prevent clipping
-         audio = normalize(audio)
-
-         # Convert to target sample rate
-         audio = audio.set_frame_rate(target_sr)
-
-         # Export to temporary WAV file with specific parameters
-         temp_path = audio_path.replace(os.path.splitext(audio_path)[1], '_processed.wav')
-         audio.export(
-             temp_path,
-             format="wav",
-             parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)]
-         )
-
-         # Validate the audio with librosa
-         wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)
-
-         # Check for invalid values
-         if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
-             raise ValueError("Audio contains NaN or infinite values")
-
-         # Ensure reasonable amplitude range
-         if np.max(np.abs(wav)) < 1e-6:
-             raise ValueError("Audio signal is too quiet")
-
-         # Re-save the validated audio
-         import soundfile as sf
-         sf.write(temp_path, wav, sr)
-
-         return temp_path
-
-     except Exception as e:
-         print(f"Audio preprocessing failed: {e}")
-         raise ValueError(f"Failed to process audio: {str(e)}")

- with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
-     with gr.Row():
-         with gr.Column():
-             reference_audio = gr.Audio(
-                 label="Reference Audio",
-                 type="filepath",
-                 sources=["upload", "microphone"]
-             )
-             text_input = gr.Textbox(
-                 label="Text to Generate",
-                 placeholder="Enter the text you want to synthesize...",
-                 lines=3
-             )
-
-             with gr.Accordion("Advanced Options", open=False):
-                 infer_timestep = gr.Number(
-                     label="Inference Timesteps",
-                     value=32,
-                     minimum=1,
-                     maximum=100,
-                     step=1
                  )
-                 p_w = gr.Number(
-                     label="Intelligibility Weight",
-                     value=1.4,
-                     minimum=0.1,
-                     maximum=5.0,
-                     step=0.1
-                 )
-                 t_w = gr.Number(
-                     label="Similarity Weight",
-                     value=3.0,
-                     minimum=0.1,
-                     maximum=10.0,
-                     step=0.1
-                 )
-                 speed_factor = gr.Slider(
-                     label="Speed Adjustment",
-                     value=1.0,
-                     minimum=0.5,
-                     maximum=2.0,
-                     step=0.1,
-                     info="1.0 = normal speed, <1.0 = slower, >1.0 = faster"
                  )

-             generate_btn = gr.Button("Generate Speech", variant="primary")

-         with gr.Column():
-             output_audio = gr.Audio(label="Generated Audio")

-     generate_btn.click(
-         fn=generate_speech,
-         inputs=[reference_audio, text_input, infer_timestep, p_w, t_w, speed_factor],
-         outputs=[output_audio]
-     )

  if __name__ == '__main__':
-     demo.launch(server_name='0.0.0.0', server_port=7860)
  import os
  import gc
+ import torch
+ import tempfile
+ import traceback
  import numpy as np
  import librosa
+ import gradio as gr
  from pydub import AudioSegment
  from pydub.effects import normalize
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
  from huggingface_hub import snapshot_download
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav

+ # CPU optimization settings
+ os.environ["OMP_NUM_THREADS"] = str(os.cpu_count() or 4)
+ os.environ["MKL_NUM_THREADS"] = str(os.cpu_count() or 4)
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+ torch.set_num_threads(os.cpu_count() or 4)

+ # Caches
+ AUDIO_CACHE = {}
+ MODEL_CACHE = None

+ class TTSEngine:
+     def __init__(self):
+         self.model = None
+         self.weights_dir = "checkpoints"
+         self.initialize_model()

+     def download_weights(self):
+         """Download the model weights if they are not already present"""
+         repo_id = "mrfakename/MegaTTS3-VoiceCloning"
+
+         if not os.path.exists(self.weights_dir):
+             print("Downloading model weights from HuggingFace...")
+             snapshot_download(
+                 repo_id=repo_id,
+                 local_dir=self.weights_dir,
+                 local_dir_use_symlinks=False,
+                 resume_download=True
+             )
+             print("Model weights downloaded successfully!")
+         else:
+             print("Model weights already exist.")

+     def initialize_model(self):
+         """Initialize the TTS model"""
+         self.download_weights()
+         print("Initializing MegaTTS3 model...")
+         self.model = MegaTTS3DiTInfer(device="cpu")
+         print(f"Model loaded successfully on CPU with {os.cpu_count()} threads!")
+
+     def reset_model(self):
+         """Reinitialize the model"""
+         try:
+             print("Reinitializing the model...")
+             self.model = MegaTTS3DiTInfer(device="cpu")
+             print("Model reinitialized successfully!")
+             return True
+         except Exception as e:
+             print(f"Failed to reinitialize the model: {e}")
+             return False
+
+     def preprocess_audio(self, audio_path, target_sr=22050, max_duration=30):
+         """Preprocess the input audio"""
+         cache_key = f"preprocessed_{hash(audio_path)}"
+         if cache_key in AUDIO_CACHE:
+             return AUDIO_CACHE[cache_key]

          try:
+             audio = AudioSegment.from_file(audio_path)
+             audio = audio.set_channels(1).set_frame_rate(target_sr)
+
+             if len(audio) > max_duration * 1000:
+                 audio = audio[:max_duration * 1000]
+
+             audio = normalize(audio)
+
+             temp_path = f"temp_{os.path.basename(audio_path)}"
+             audio.export(
+                 temp_path,
+                 format="wav",
+                 parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)]
+             )
+
+             # Validate the audio quality
+             wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)
+             if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
+                 raise ValueError("Audio contains invalid values")
+
+             if np.max(np.abs(wav)) < 1e-6:
+                 raise ValueError("Audio signal is too quiet")
+
+             import soundfile as sf
+             sf.write(temp_path, wav, sr)
+
+             AUDIO_CACHE[cache_key] = temp_path
+             return temp_path
+
+         except Exception as e:
+             print(f"Audio preprocessing failed: {e}")
+             raise ValueError(f"Failed to process audio: {str(e)}")
+
+     def process_sentence(self, audio_context, sentence, params):
+         """Process a single sentence"""
+         try:
+             with torch.no_grad():
+                 wav_bytes = self.model.forward(
+                     audio_context,
+                     sentence,
+                     time_step=params['infer_timestep'],
+                     p_w=params['p_w'],
+                     t_w=params['t_w']
+                 )
+
+             if params['speed_factor'] != 1.0:
+                 wav_bytes = self.adjust_speed(wav_bytes, params['speed_factor'])
+
+             return wav_bytes
+         except Exception as e:
+             print(f"Failed to process sentence: {sentence[:50]}... - {str(e)}")
              return None
+
+     def adjust_speed(self, wav_bytes, speed_factor):
+         """Adjust the playback speed of the audio"""
          try:
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_input:
+                 temp_input.write(wav_bytes)
+                 temp_input_path = temp_input.name
+
+             audio = AudioSegment.from_file(temp_input_path)

              if speed_factor != 1.0:
+                 new_frame_rate = int(audio.frame_rate * speed_factor)
+                 audio = audio._spawn(audio.raw_data, overrides={
+                     "frame_rate": new_frame_rate
+                 }).set_frame_rate(audio.frame_rate)

+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
+                 audio.export(temp_output.name, format="wav")
+                 with open(temp_output.name, "rb") as f:
+                     result = f.read()
+
+             os.unlink(temp_input_path)
+             os.unlink(temp_output.name)
+
+             return result
+         except Exception as e:
+             print(f"Speed adjustment failed: {e}")
              return wav_bytes

+     def generate_speech(self, inp_audio, inp_text, params):
+         """Generate speech from text"""
+         if not inp_audio or not inp_text:
+             gr.Warning("Please provide both the reference audio and the text to convert.")
+             return None

+         try:
+             print(f"Generating speech for a text of {len(inp_text)} characters...")
+
+             # Preprocess the input audio, using the cache
+             cache_key = f"audio_{hash(inp_audio)}"
+             if cache_key not in AUDIO_CACHE:
+                 processed_audio_path = self.preprocess_audio(inp_audio)
+                 cut_wav(processed_audio_path, max_len=28)
+
+                 with open(processed_audio_path, 'rb') as file:
+                     file_content = file.read()
+
+                 audio_context = self.model.preprocess(file_content)
+                 AUDIO_CACHE[cache_key] = audio_context
+             else:
+                 audio_context = AUDIO_CACHE[cache_key]
+                 print("Using cached reference audio")
+
+             # Split the text into sentences
+             sentences = [s.strip() for s in inp_text.split('.') if s.strip()]
+
+             if not sentences:
+                 gr.Warning("No sentences found in the text")
+                 return None
+
+             # Process the sentences in parallel
+             with ThreadPoolExecutor(max_workers=min(4, len(sentences))) as executor:
+                 process_fn = partial(self.process_sentence, audio_context, params=params)
+                 results = list(executor.map(process_fn, sentences))
+
+             # Concatenate the audio segments
+             combined_audio = None
+             for result in results:
+                 if result is None:
+                     continue
+
+                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
+                     temp_file.write(result)
+                     temp_path = temp_file.name
+
+                 segment = AudioSegment.from_file(temp_path)
+                 os.unlink(temp_path)
+
+                 if combined_audio is None:
+                     combined_audio = segment
+                 else:
+                     combined_audio += AudioSegment.silent(duration=200)  # add a 200 ms pause between sentences
+                     combined_audio += segment
+
+             if combined_audio is None:
+                 gr.Warning("Could not generate any audio segments")
+                 return None
+
+             # Export the final result
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as output_file:
+                 combined_audio.export(output_file.name, format="wav")
+                 with open(output_file.name, "rb") as f:
+                     final_result = f.read()
+                 os.unlink(output_file.name)
+
+             self.cleanup_memory()
+             return final_result
+
+         except Exception as e:
+             traceback.print_exc()
+             gr.Warning(f"Speech generation failed: {str(e)}")
+             self.cleanup_memory()
+             return None

+     def cleanup_memory(self):
+         """Free up memory"""
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         AUDIO_CACHE.clear()

+ # Initialize the TTS engine
+ tts_engine = TTSEngine()

+ # Gradio interface
+ def create_gradio_interface():
+     with gr.Blocks(title="MegaTTS3 - Text to Speech") as demo:
+         with gr.Row():
+             with gr.Column():
+                 reference_audio = gr.Audio(
+                     label="Reference Audio",
+                     type="filepath",
+                     sources=["upload", "microphone"]
                  )
+                 text_input = gr.Textbox(
+                     label="Text to Convert",
+                     placeholder="Enter the text you want to turn into speech...",
+                     lines=5
                  )
+
+                 with gr.Accordion("Advanced Options", open=False):
+                     infer_timestep = gr.Slider(
+                         label="Inference Timesteps",
+                         value=32,
+                         minimum=1,
+                         maximum=100,
+                         step=1
+                     )
+                     p_w = gr.Slider(
+                         label="Intelligibility Weight",
+                         value=1.4,
+                         minimum=0.1,
+                         maximum=5.0,
+                         step=0.1
+                     )
+                     t_w = gr.Slider(
+                         label="Similarity Weight",
+                         value=3.0,
+                         minimum=0.1,
+                         maximum=10.0,
+                         step=0.1
+                     )
+                     speed_factor = gr.Slider(
+                         label="Speed Adjustment",
+                         value=1.0,
+                         minimum=0.5,
+                         maximum=2.0,
+                         step=0.1,
+                         info="1.0 = normal, <1.0 = slower, >1.0 = faster"
+                     )
+
+                 generate_btn = gr.Button("Generate Speech", variant="primary")

+             with gr.Column():
+                 output_audio = gr.Audio(label="Generated Audio")
+                 status = gr.Textbox(label="Status")

+         generate_btn.click(
+             fn=generate_speech_wrapper,
+             inputs=[reference_audio, text_input, infer_timestep, p_w, t_w, speed_factor],
+             outputs=[output_audio, status]
+         )

+     return demo
+
+ def generate_speech_wrapper(audio, text, timestep, p_w, t_w, speed):
+     params = {
+         'infer_timestep': timestep,
+         'p_w': p_w,
+         't_w': t_w,
+         'speed_factor': speed
+     }
+     result = tts_engine.generate_speech(audio, text, params)
+     status = "Done!" if result else "An error occurred!"
+     return result, status

  if __name__ == '__main__':
+     demo = create_gradio_interface()
+     demo.launch(
+         server_name='0.0.0.0',
+         server_port=7860,
+         share=False,
+         show_error=True
+     )
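
For reference, a minimal way to exercise the new generate_speech_wrapper entry point outside the Gradio UI might look like the sketch below. It assumes this app.py is importable from the Space directory (importing it downloads the checkpoints and initializes the CPU model), and ref.wav / out.wav are placeholder file names, not files shipped with the repo.

# Hypothetical local smoke test for the refactored pipeline (illustrative only).
# Importing app triggers the checkpoint download and CPU model initialization.
from app import generate_speech_wrapper

wav_bytes, status = generate_speech_wrapper(
    audio="ref.wav",                        # placeholder path to a short reference voice clip
    text="Hello there. This is a test.",    # split on '.' into two sentences and synthesized in parallel
    timestep=32,                            # inference timesteps
    p_w=1.4,                                # intelligibility weight
    t_w=3.0,                                # similarity weight
    speed=1.0,                              # 1.0 = normal playback speed
)
print(status)
if wav_bytes:
    with open("out.wav", "wb") as f:        # write the returned WAV bytes to disk
        f.write(wav_bytes)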