# Spaces: Running on Zero
import os
import sys
import tempfile
import traceback
import uuid

import numpy as np
import torch
import soundfile as sf
import spaces

from config import models_path, results_path, sample_path, BASE_DIR
# Module-level model handles. They start as None; load_models() assigns
# them and generate_speech() reads them.
encoder = synthesizer = vocoder = None
def load_models():
    """Load the speaker encoder, synthesizer and HiFi-GAN vocoder into module globals.

    ``<BASE_DIR>/pmt2`` is added to ``sys.path`` (presumably where the bundled
    ``encoder`` / ``synthesizer`` / ``parallel_wavegan`` packages live) and the
    checkpoints ``encoder.pt``, ``synthesizer.pt`` and ``vocoder_HiFiGAN.pkl``
    are read from ``models_path``.

    The globals ``encoder``, ``synthesizer`` and ``vocoder`` are assigned only
    after every model has loaded, so a failure part-way through never leaves a
    half-initialised set behind.

    Returns:
        True on success; False if any step raised (the traceback is printed).
    """
    global encoder, synthesizer, vocoder
    try:
        repo_dir = os.path.join(BASE_DIR, 'pmt2')
        # Don't grow sys.path with a duplicate entry on every call.
        if repo_dir not in sys.path:
            sys.path.append(repo_dir)
        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print("Loading encoder model...")
        encoder_module.load_model(os.path.join(models_path, 'encoder.pt'))
        print("Loading synthesizer model...")
        new_synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
        print("Loading HiFiGAN vocoder...")
        new_vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        # Strip weight norm and switch to eval mode for inference.
        new_vocoder.remove_weight_norm()
        new_vocoder = new_vocoder.eval().to(device)
    except Exception:
        # Best-effort loader: report and signal failure instead of raising.
        print(f"Error loading models: {traceback.format_exc()}")
        return False
    # Commit all three together only once everything succeeded.
    encoder, synthesizer, vocoder = encoder_module, new_synthesizer, new_vocoder
    return True
def generate_speech(text, reference_audio=None):
    """Synthesize ``text`` in the voice of a reference recording and save it as WAV.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only input returns ``None``.
        reference_audio: Optional pair whose index 0 is written as the sample rate
            and index 1 as the sample data. When ``None``, the recording at
            ``sample_path`` is used as the voice reference.

    Returns:
        Path of the generated WAV file inside ``results_path``, or ``None`` if the
        input is empty or any step fails (the traceback is printed).

    Relies on the module globals populated by ``load_models()``; if they are
    missing, the resulting error is caught and ``None`` is returned.
    """
    if not text or not text.strip():
        return None
    temp_ref_path = None
    try:
        if reference_audio is None:
            ref_wav_path = sample_path
        else:
            sample_rate, samples = reference_audio[0], reference_audio[1]
            # One unique file per request, so concurrent calls cannot overwrite
            # each other's reference clip between the write and the read.
            fd, temp_ref_path = tempfile.mkstemp(
                prefix="reference_audio_", suffix=".wav", dir=results_path
            )
            os.close(fd)
            sf.write(temp_ref_path, samples, sample_rate)
            ref_wav_path = temp_ref_path
        print(f"Using reference audio: {ref_wav_path}")

        # Speaker embedding of the reference voice.
        ref_wav = synthesizer.load_preprocess_wav(ref_wav_path)
        encoder_wav = encoder.preprocess_wav(ref_wav)
        embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Text + speaker embedding -> spectrogram(s), joined along axis 1.
        texts = [text]
        embeds = [embed] * len(texts)
        specs = synthesizer.synthesize_spectrograms(texts, embeds)
        spec = np.concatenate(specs, axis=1)

        # Spectrogram -> waveform (same device rule as load_models uses).
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        x = torch.from_numpy(spec.T).to(device)
        with torch.no_grad():
            wav = vocoder.inference(x)
        wav = wav.cpu().numpy()

        # Peak-normalise to 0.97; skip silent/empty output to avoid 0/0 -> NaN.
        peak = np.abs(wav).max() if wav.size else 0.0
        if peak > 0:
            wav = wav / peak * 0.97

        # hash(text) is salted per process and `% 10000` collides, so different
        # requests could overwrite each other's output; use a unique name.
        output_path = os.path.join(results_path, f"generated_{uuid.uuid4().hex}.wav")
        sf.write(output_path, wav, synthesizer.sample_rate)
        return output_path
    except Exception:
        print(f"Error generating speech: {traceback.format_exc()}")
        return None
    finally:
        # The uploaded reference clip is only needed for this request.
        if temp_ref_path is not None:
            try:
                os.remove(temp_ref_path)
            except OSError:
                pass