|
import json
import logging

import gradio as gr
import soundfile as sf
import torch
from torch import nn
|
|
|
|
|
class TTSModel(nn.Module):
    """Stub text-to-speech module.

    The module stores the configuration it is given and registers no layers
    or parameters of its own. ``forward`` does not look at its input and
    always produces the same all-zero waveform.
    """

    def __init__(self, config):
        """Keep ``config`` on the instance; nothing else is built from it here."""
        super().__init__()
        self.config = config

    def forward(self, inputs):
        """Return a 1-D tensor of 22050 zeros, regardless of ``inputs``."""
        num_samples = 22050
        silence = torch.zeros(num_samples)
        return silence
|
|
|
|
|
def preprocess_text(text):
    """Return ``text`` unchanged.

    This is currently a pass-through hook: no normalization or
    tokenization is applied before the text is handed to the model.
    """
    prepared = text
    return prepared
|
|
|
|
|
def text_to_speech(text):
    """Run the module-level ``model`` on ``text`` and return the audio as NumPy.

    Args:
        text: Input text; it is passed through ``preprocess_text`` first.

    Returns:
        The model output converted to a ``numpy.ndarray``.
    """
    inputs = preprocess_text(text)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        audio_output = model(inputs)

    # Diagnostics go through logging at DEBUG level instead of unconditional
    # print() calls, so they are silent unless explicitly enabled.
    logging.getLogger(__name__).debug(
        "model output: type=%s shape=%s",
        type(audio_output).__name__,
        tuple(audio_output.shape),
    )

    # detach() and cpu() keep the conversion valid even when the tensor
    # requires grad or lives on a non-CPU device; for a plain CPU tensor
    # both calls are no-ops.
    return audio_output.detach().cpu().numpy()
|
|
|
|
|
def load_models():
    """Build the duration network and the synthesizer and load their weights.

    Returns:
        A ``(duration_net, generator)`` tuple with both modules switched to
        eval mode, or ``(None, None)`` if building or loading either one
        raises; in that case the error is printed, not propagated.
    """
    # NOTE(review): DurationNet, SynthesizerTrn, hps, device and
    # lightspeed_model_path are not defined in the visible part of this
    # file. Confirm they are provided elsewhere; if not, the resulting
    # NameError is caught below and reported as a load failure.
    try:
        duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
        duration_state = torch.load(duration_model_path, map_location=device)
        duration_net.load_state_dict(duration_state)
        duration_net.eval()
        print("DurationNet loaded successfully.")
    except Exception as err:
        print(f"Error loading DurationNet: {err}")
        return None, None

    try:
        generator = SynthesizerTrn(
            hps.data.vocab_size,
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **vars(hps.model),
        ).to(device)

        # The enc_q submodule is removed before the weights are loaded;
        # strict=False below tolerates checkpoint keys that no longer match.
        del generator.enc_q

        ckpt = torch.load(lightspeed_model_path, map_location=device)

        # Drop a leading "module." from parameter names so the keys line up
        # with the module built above.
        prefix = "module."
        params = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in ckpt["net_g"].items()
        }

        generator.load_state_dict(params, strict=False)
        generator.eval()
        print("SynthesizerTrn loaded successfully.")
    except Exception as err:
        print(f"Error loading SynthesizerTrn: {err}")
        return None, None

    return duration_net, generator
|
|
|
|
|
|
|
|
|
|
|
# Locations of the JSON configuration and the two checkpoint files.
config_path = "config.json"
duration_model_path = "vbx_duration_model.pth"
generation_model_path = "gen_619k.pth"

# Read the configuration with an explicit encoding so that parsing does not
# depend on the platform's default locale.
with open(config_path, 'r', encoding="utf-8") as f:
    config = json.load(f)

# Build the model once at import time and put it in inference mode.
model = TTSModel(config)
model.eval()


def _load_checkpoint_into(module, checkpoint_path):
    """Load ``checkpoint_path`` into ``module`` non-strictly and report mismatches.

    ``strict=False`` makes ``load_state_dict`` skip every key that does not
    match instead of raising, so a checkpoint that matches nothing would
    otherwise load "successfully" without any visible sign. The counts of
    ignored and missing keys are printed to make that situation noticeable.
    """
    state = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    result = module.load_state_dict(state, strict=False)
    if result.unexpected_keys or result.missing_keys:
        print(
            f"Warning: {checkpoint_path}: "
            f"{len(result.unexpected_keys)} checkpoint key(s) ignored, "
            f"{len(result.missing_keys)} model key(s) missing."
        )
    return result


# Both checkpoints are loaded into the same model, in the original order.
# NOTE(review): TTSModel as visible in this file registers no parameters, so
# every checkpoint key is expected to be reported as ignored — confirm this
# is intended.
_load_checkpoint_into(model, duration_model_path)
_load_checkpoint_into(model, generation_model_path)
|
|
|
|
|
|
|
def infer(text, output_path='output.wav', sample_rate=22050):
    """Synthesize ``text``, write it to an audio file and return the file path.

    Args:
        text: Text to synthesize; passed to ``text_to_speech``.
        output_path: Where the audio file is written. Defaults to
            ``'output.wav'``, the path that was previously hard-coded.
        sample_rate: Sample rate passed to ``soundfile.write``. Defaults to
            22050, the value that was previously hard-coded.

    Returns:
        ``output_path``, so the caller can pick the file up.
    """
    # NOTE(review): with the default path every call overwrites the same
    # file; pass a unique ``output_path`` if calls can overlap.
    audio = text_to_speech(text)
    sf.write(output_path, audio, sample_rate)
    return output_path
|
|
|
# Gradio UI: one text box in, one audio player out, backed by ``infer``.
# The description is Vietnamese for "Convert Vietnamese text to speech."
iface = gr.Interface(
    fn=infer,
    inputs="text",
    outputs="audio",
    title="Text to Speech",
    description="Chuyển đổi văn bản tiếng Việt thành giọng nói.",
)

# Start the web server only when this file is run as a script, so that
# importing the module (e.g. to reuse ``infer`` or ``iface``) does not block
# on a running server.
if __name__ == "__main__":
    iface.launch()
|
|