mrfakename committed on
Commit
1ca3adb
·
1 Parent(s): cb4e009

fix stability

Browse files
Files changed (4) hide show
  1. app.py +113 -8
  2. packages.txt +1 -0
  3. requirements.txt +2 -1
  4. tts/frontend_function.py +20 -1
app.py CHANGED
@@ -4,6 +4,11 @@ import os
4
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
5
  import gradio as gr
6
  import traceback
 
 
 
 
 
7
  from huggingface_hub import snapshot_download
8
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
9
 
@@ -33,6 +38,21 @@ print("Initializing MegaTTS3 model...")
33
  infer_pipe = MegaTTS3DiTInfer()
34
  print("Model loaded successfully!")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  @spaces.GPU
37
  def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
38
  if not inp_audio or not inp_text:
@@ -42,25 +62,110 @@ def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
42
  try:
43
  print(f"Generating speech with: {inp_text}...")
44
 
45
- # Convert and prepare audio
46
- convert_to_wav(inp_audio)
47
- wav_path = os.path.splitext(inp_audio)[0] + '.wav'
48
- cut_wav(wav_path, max_len=28)
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Read audio file
51
  with open(wav_path, 'rb') as file:
52
  file_content = file.read()
53
 
54
- # Generate speech
55
- resource_context = infer_pipe.preprocess(file_content)
56
- wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- return wav_bytes
59
  except Exception as e:
60
  traceback.print_exc()
61
  gr.Warning(f"Speech generation failed: {str(e)}")
 
 
62
  return None
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
66
  gr.Markdown("# MegaTTS 3 Voice Cloning")
 
4
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
5
  import gradio as gr
6
  import traceback
7
+ import gc
8
+ import numpy as np
9
+ import librosa
10
+ from pydub import AudioSegment
11
+ from pydub.effects import normalize
12
  from huggingface_hub import snapshot_download
13
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
14
 
 
38
  infer_pipe = MegaTTS3DiTInfer()
39
  print("Model loaded successfully!")
40
 
41
def reset_model():
    """Rebuild the global inference pipeline after a CUDA failure.

    When a GPU is available, cached CUDA memory is released and the device
    is synchronized first. The module-level ``infer_pipe`` is then rebound
    to a freshly constructed ``MegaTTS3DiTInfer``.

    Returns:
        bool: ``True`` if the pipeline was rebuilt, ``False`` if any step
        raised. Failures are printed, never propagated.
    """
    global infer_pipe
    try:
        gpu_present = torch.cuda.is_available()
        if gpu_present:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print("Reinitializing MegaTTS3 model...")
        fresh_pipe = MegaTTS3DiTInfer()
        # Rebind only after construction succeeded, so a failed rebuild
        # leaves the previous pipeline object in place.
        infer_pipe = fresh_pipe
        print("Model reinitialized successfully!")
        return True
    except Exception as err:
        print(f"Failed to reinitialize model: {err}")
        return False
55
+
56
  @spaces.GPU
57
  def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
58
  if not inp_audio or not inp_text:
 
62
  try:
63
  print(f"Generating speech with: {inp_text}...")
64
 
65
+ # Check CUDA availability and clear cache
66
+ if torch.cuda.is_available():
67
+ torch.cuda.empty_cache()
68
+ print(f"CUDA device: {torch.cuda.get_device_name()}")
69
+ else:
70
+ gr.Warning("CUDA is not available. Please check your GPU setup.")
71
+ return None
72
+
73
+ # Robustly preprocess audio
74
+ try:
75
+ processed_audio_path = preprocess_audio_robust(inp_audio)
76
+ # Use existing cut_wav for final trimming
77
+ cut_wav(processed_audio_path, max_len=28)
78
+ wav_path = processed_audio_path
79
+ except Exception as audio_error:
80
+ gr.Warning(f"Audio preprocessing failed: {str(audio_error)}")
81
+ return None
82
 
83
  # Read audio file
84
  with open(wav_path, 'rb') as file:
85
  file_content = file.read()
86
 
87
+ # Generate speech with proper error handling
88
+ try:
89
+ resource_context = infer_pipe.preprocess(file_content)
90
+ wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
91
+ # Clean up memory after successful generation
92
+ cleanup_memory()
93
+ return wav_bytes
94
+ except RuntimeError as cuda_error:
95
+ if "CUDA" in str(cuda_error):
96
+ print(f"CUDA error detected: {cuda_error}")
97
+ # Try to reset the model to recover from CUDA errors
98
+ if reset_model():
99
+ gr.Warning("CUDA error occurred. Model has been reset. Please try again.")
100
+ else:
101
+ gr.Warning("CUDA error occurred and model reset failed. Please restart the application.")
102
+ return None
103
+ else:
104
+ raise cuda_error
105
 
 
106
  except Exception as e:
107
  traceback.print_exc()
108
  gr.Warning(f"Speech generation failed: {str(e)}")
109
+ # Clean up CUDA memory on any error
110
+ cleanup_memory()
111
  return None
112
 
113
def cleanup_memory():
    """Run Python garbage collection and release cached CUDA memory.

    The CUDA steps are skipped entirely when no GPU is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
119
+
120
def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
    """Convert an audio file into a validated mono WAV file.

    The input is loaded with pydub, reduced to one channel, truncated to
    ``max_duration`` seconds, normalized, resampled to ``target_sr`` and
    exported as 16-bit PCM WAV. The exported file is then re-read with
    librosa, rejected if it contains NaN/inf samples or is effectively
    silent, and written back with soundfile.

    Args:
        audio_path: Path of the source audio file.
        target_sr: Sample rate in Hz of the produced WAV file.
        max_duration: Maximum length in seconds kept from the start of the
            audio.

    Returns:
        Path of the processed WAV file: ``audio_path`` with its extension
        replaced by ``_processed.wav``.

    Raises:
        ValueError: If any step fails. The underlying exception is chained
            as ``__cause__``.
    """
    try:
        # Load with pydub for robust format handling
        audio = AudioSegment.from_file(audio_path)

        # Convert to mono if stereo
        if audio.channels > 1:
            audio = audio.set_channels(1)

        # pydub measures and slices audio in milliseconds; int() keeps the
        # slice bound integral even for a fractional max_duration.
        max_ms = int(max_duration * 1000)
        if len(audio) > max_ms:
            audio = audio[:max_ms]

        # Normalize audio to prevent clipping
        audio = normalize(audio)

        # Convert to target sample rate
        audio = audio.set_frame_rate(target_sr)

        # Swap only the trailing extension. Using str.replace() on the
        # extension text would also rewrite earlier occurrences inside the
        # path and would mangle a path that has no extension at all.
        temp_path = os.path.splitext(audio_path)[0] + '_processed.wav'
        audio.export(
            temp_path,
            format="wav",
            parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)],
        )

        # Validate the exported audio with librosa
        wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)

        # Check for invalid values
        if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
            raise ValueError("Audio contains NaN or infinite values")

        # Ensure reasonable amplitude range
        if np.max(np.abs(wav)) < 1e-6:
            raise ValueError("Audio signal is too quiet")

        # Re-save the validated audio
        import soundfile as sf
        sf.write(temp_path, wav, sr)

        return temp_path

    except Exception as e:
        print(f"Audio preprocessing failed: {e}")
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Failed to process audio: {str(e)}") from e
168
+
169
 
170
  with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
171
  gr.Markdown("# MegaTTS 3 Voice Cloning")
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt CHANGED
@@ -16,4 +16,5 @@ torchdiffeq==0.2.5
16
  openai-whisper==20240930
17
  httpx==0.28.1
18
  gradio==5.23.1
19
- hf-transfer
 
 
16
  openai-whisper==20240930
17
  httpx==0.28.1
18
  gradio==5.23.1
19
+ hf-transfer
20
+ soundfile
tts/frontend_function.py CHANGED
@@ -16,6 +16,7 @@ import torch
16
  import torch.nn.functional as F
17
  import whisper
18
  import librosa
 
19
  from copy import deepcopy
20
  from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
21
  from tts.utils.audio_utils.align import mel2token_to_dur
@@ -39,8 +40,26 @@ def g2p(self, text_inp):
39
  ''' Get phoneme2mel align of prompt speech '''
40
  def align(self, wav):
41
  with torch.inference_mode():
 
 
 
 
42
  whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
43
- mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  prompt_max_frame = mel.size(2) // self.fm * self.fm
45
  mel = mel[:, :, :prompt_max_frame]
46
  token = torch.LongTensor([[798]]).to(self.device)
 
16
  import torch.nn.functional as F
17
  import whisper
18
  import librosa
19
+ import numpy as np
20
  from copy import deepcopy
21
  from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
22
  from tts.utils.audio_utils.align import mel2token_to_dur
 
40
  ''' Get phoneme2mel align of prompt speech '''
41
  def align(self, wav):
42
  with torch.inference_mode():
43
+ # Validate input audio
44
+ if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
45
+ raise ValueError("Input audio contains NaN or infinite values")
46
+
47
  whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
48
+
49
+ # Validate resampled audio
50
+ if np.any(np.isnan(whisper_wav)) or np.any(np.isinf(whisper_wav)):
51
+ raise ValueError("Resampled audio contains NaN or infinite values")
52
+
53
+ # Get mel spectrogram with validation
54
+ mel_spec = whisper.log_mel_spectrogram(whisper_wav)
55
+ if np.any(np.isnan(mel_spec)) or np.any(np.isinf(mel_spec)):
56
+ raise ValueError("Mel spectrogram contains NaN or infinite values")
57
+
58
+ mel = torch.FloatTensor(mel_spec.T).to(self.device)[None].transpose(1,2)
59
+
60
+ # Validate tensor before further processing
61
+ if torch.any(torch.isnan(mel)) or torch.any(torch.isinf(mel)):
62
+ raise ValueError("Mel tensor contains NaN or infinite values")
63
  prompt_max_frame = mel.size(2) // self.fm * self.fm
64
  mel = mel[:, :, :prompt_max_frame]
65
  token = torch.LongTensor([[798]]).to(self.device)