|
from transformers import VitsModel, AutoTokenizer |
|
import torch |
|
import scipy.io.wavfile |
|
import util |
|
|
|
|
|
# Hugging Face checkpoint: MMS text-to-speech, Uyghur in Arabic script.
model_id = "facebook/mms-tts-uig-script_arabic"

# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the tokenizer and VITS model once, at import time, so every call to
# generate_audio reuses them. The model is placed on `device`.
tts_tokenizer = AutoTokenizer.from_pretrained(model_id)
tts_model = VitsModel.from_pretrained(model_id).to(device)
|
|
|
def generate_audio(input_text, script, *, output_path="tts_output.wav"):
    """Synthesize Uyghur speech for ``input_text`` and save it as a WAV file.

    Args:
        input_text: Text to speak.
        script: Script that ``input_text`` is written in. The exact value
            ``"Uyghur Arabic"`` means the text is already in Arabic script and
            is used as-is. Any other value sends the text through
            ``util.ug_latn_to_arab`` first (presumably Latin-script Uyghur;
            confirm against ``util``).
        output_path: Where to write the WAV file. The default,
            ``"tts_output.wav"`` (relative to the current working directory),
            is overwritten by every call. Pass a distinct path when calls may
            overlap, e.g. concurrent requests in a web UI.

    Returns:
        ``output_path``, the path of the written WAV file.
    """
    if script != "Uyghur Arabic":
        # The loaded checkpoint is the Arabic-script variant, so convert the
        # text to Arabic script before tokenizing.
        input_text = util.ug_latn_to_arab(input_text)

    tts_inputs = tts_tokenizer(input_text, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        waveform = tts_model(**tts_inputs).waveform.cpu()

    # Take the rate from the checkpoint's config instead of hard-coding it.
    # This keeps the WAV header in sync with the audio if model_id changes.
    # 16000 (the previously hard-coded value) is only a fallback.
    sample_rate = getattr(tts_model.config, "sampling_rate", 16000)

    # The first axis of `waveform` is the batch; a single text was tokenized,
    # so item 0 is the whole utterance.
    scipy.io.wavfile.write(output_path, rate=sample_rate, data=waveform.numpy()[0])

    return output_path