vuxuanhoan's picture
Update app.py
65fb3de verified
import json
import torch
import gradio as gr
from torch import nn
import soundfile as sf
# Định nghĩa lớp mô hình TTS
class TTSModel(nn.Module):
def __init__(self, config):
super().__init__()
# Khởi tạo các lớp của mô hình TTS ở đây
# Giả định bạn đã xác định cách xây dựng mô hình dựa trên cấu hình
self.config = config
# Các thành phần khác của mô hình sẽ được thêm vào đây.
def forward(self, inputs):
# Logic của mô hình để chuyển đổi văn bản thành giọng nói
# Giả định rằng bạn đã thực hiện điều này
return torch.zeros(22050) # Trả về một tensor âm thanh giả định
# Hàm tiền xử lý văn bản
def preprocess_text(text):
# Chuyển đổi văn bản thành dạng số (encoding)
return text # Đây chỉ là một ví dụ đơn giản
# Hàm chuyển đổi văn bản thành giọng nói
def text_to_speech(text):
inputs = preprocess_text(text)
with torch.no_grad():
audio_output = model(inputs)
print(type(audio_output)) # In kiểu dữ liệu
print(audio_output.shape) # In kích thước đầu ra
print(audio_output[:10]) # In 10 giá trị đầu tiên của âm thanh
# Giả sử rằng audio_output là một tensor âm thanh
return audio_output.numpy()
def load_models():
try:
duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
duration_net.eval()
print("DurationNet loaded successfully.")
except Exception as e:
print(f"Error loading DurationNet: {e}")
return None, None
try:
generator = SynthesizerTrn(
hps.data.vocab_size,
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**vars(hps.model),
).to(device)
del generator.enc_q
ckpt = torch.load(lightspeed_model_path, map_location=device)
params = {}
for k, v in ckpt["net_g"].items():
k = k[7:] if k.startswith("module.") else k
params[k] = v
generator.load_state_dict(params, strict=False)
generator.eval()
print("SynthesizerTrn loaded successfully.")
except Exception as e:
print(f"Error loading SynthesizerTrn: {e}")
return None, None
return duration_net, generator
# Tải cấu hình và trọng số mô hình
config_path = "config.json"
duration_model_path = "vbx_duration_model.pth"
generation_model_path = "gen_619k.pth"
# Tải cấu hình
with open(config_path, 'r') as f:
config = json.load(f)
# Tạo mô hình
model = TTSModel(config)
model.eval() # Chuyển mô hình về chế độ đánh giá
# Tải trọng số mô hình
model.load_state_dict(torch.load(duration_model_path, map_location=torch.device('cpu'), weights_only=True), strict=False)
model.load_state_dict(torch.load(generation_model_path, map_location=torch.device('cpu'), weights_only=True), strict=False)
# Xây dựng giao diện Gradio
def infer(text):
audio = text_to_speech(text)
sf.write('output.wav', audio, 22050) # Lưu âm thanh vào tệp WAV
return 'output.wav'
iface = gr.Interface(fn=infer, inputs="text", outputs="audio", title="Text to Speech",
description="Chuyển đổi văn bản tiếng Việt thành giọng nói.")
iface.launch()