"""Text-to-speech with voice cloning: speaker encoder + synthesizer + HiFiGAN vocoder.

``load_models`` populates the module-level ``encoder`` / ``synthesizer`` /
``vocoder`` globals; ``generate_speech`` uses them to synthesise ``text`` in
the voice of a reference recording and writes the result to ``results_path``.
"""

import os
import sys
import traceback
import uuid
from contextlib import suppress

import numpy as np
import torch
import soundfile as sf
import spaces

from config import models_path, results_path, sample_path, BASE_DIR

# Populated by load_models(); remain None until loading succeeds.
encoder = None
synthesizer = None
vocoder = None

# The generated waveform is rescaled so its largest absolute sample equals
# this value (a little headroom below full scale).
_OUTPUT_PEAK = 0.97


def _device():
    """Return the torch device string used for the vocoder ('cuda' if available, else 'cpu')."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def load_models():
    """Load the encoder, synthesizer and HiFiGAN vocoder into the module globals.

    Weights are read from ``models_path`` (``encoder.pt``, ``synthesizer.pt``,
    ``vocoder_HiFiGAN.pkl``). ``BASE_DIR/pmt2`` is added to ``sys.path`` so that
    its ``encoder`` and ``synthesizer`` packages can be imported.

    Returns:
        True if every model loaded; False if any step raised (the traceback
        is printed). On failure, models loaded before the error stay assigned.
    """
    global encoder, synthesizer, vocoder
    try:
        pmt2_dir = os.path.join(BASE_DIR, 'pmt2')
        # Guard so repeated calls don't keep growing sys.path.
        if pmt2_dir not in sys.path:
            sys.path.append(pmt2_dir)

        # These imports bind only the aliases, not the global names above.
        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        encoder = encoder_module
        print("Loading encoder model...")
        encoder.load_model(os.path.join(models_path, 'encoder.pt'))

        print("Loading synthesizer model...")
        synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))

        print("Loading HiFiGAN vocoder...")
        vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(_device())
        return True
    except Exception:
        print(f"Error loading models: {traceback.format_exc()}")
        return False


def _resolve_reference(reference_audio):
    """Return ``(path, is_temporary)`` for the reference clip to clone from.

    ``None`` selects the bundled ``sample_path``; a path-like is used as-is.
    Anything else is indexed as ``(sample_rate, samples)`` and written to a
    uniquely named WAV in ``results_path``. That format is presumably Gradio's
    numpy audio tuple; confirm against the caller.
    """
    if reference_audio is None:
        return sample_path, False
    if isinstance(reference_audio, (str, os.PathLike)):
        return os.fspath(reference_audio), False
    sample_rate, samples = reference_audio[0], reference_audio[1]
    # Unique name per request so concurrent requests don't overwrite each other.
    path = os.path.join(results_path, f"reference_audio_{uuid.uuid4().hex}.wav")
    sf.write(path, samples, sample_rate)
    return path, True


@spaces.GPU(duration=120)
def generate_speech(text, reference_audio=None):
    """Synthesise ``text`` in the voice of ``reference_audio``.

    Args:
        text: Text to speak. Empty or whitespace-only text returns None.
        reference_audio: None (use ``sample_path``), a path to an audio file,
            or a ``(sample_rate, samples)`` pair that is written to a
            temporary WAV and deleted afterwards.

    Returns:
        Path of the generated WAV inside ``results_path``, or None if the text
        is empty, the models are not loaded, or generation raised (the
        traceback is printed).
    """
    if not text or not text.strip():
        return None
    if encoder is None or synthesizer is None or vocoder is None:
        print("Error generating speech: models are not loaded; call load_models() first")
        return None

    temp_ref = None
    try:
        ref_wav_path, is_temporary = _resolve_reference(reference_audio)
        if is_temporary:
            temp_ref = ref_wav_path
        print(f"Using reference audio: {ref_wav_path}")

        # Speaker embedding from the reference clip.
        ref_wav = synthesizer.load_preprocess_wav(ref_wav_path)
        encoder_wav = encoder.preprocess_wav(ref_wav)
        embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Text + embedding -> spectrogram(s), joined along axis 1.
        specs = synthesizer.synthesize_spectrograms([text], [embed])
        spec = np.concatenate(specs, axis=1)

        # The vocoder receives the transposed spectrogram (spec.T).
        mel = torch.from_numpy(spec.T).to(_device())
        with torch.no_grad():
            audio = vocoder.inference(mel)
        audio = audio.cpu().numpy()

        # Peak-normalise; skip for silent output to avoid dividing by zero.
        peak = np.abs(audio).max()
        if peak > 0:
            audio = audio / peak * _OUTPUT_PEAK

        # Unique name: hash(text) % 10000 collided often and is randomised per process.
        output_path = os.path.join(results_path, f"generated_{uuid.uuid4().hex}.wav")
        sf.write(output_path, audio, synthesizer.sample_rate)
        return output_path
    except Exception:
        print(f"Error generating speech: {traceback.format_exc()}")
        return None
    finally:
        if temp_ref is not None:
            with suppress(OSError):
                os.remove(temp_ref)