File size: 1,033 Bytes
bef8623
 
30e5da4
1dfec92
bef8623
 
 
 
 
 
1dfec92
bef8623
 
30e5da4
bef8623
 
1dfec92
 
bef8623
30e5da4
 
4c14db4
bef8623
8377a77
bef8623
30e5da4
a651122
30e5da4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile
import util

# Registry of available TTS back-ends, keyed by the model id that
# `synthesize` accepts. Each entry holds:
#   "processor"     - tokenizer that turns input text into model inputs
#   "model"         - VITS model that produces the waveform
#   "arabic_script" - whether `synthesize` must run util.ug_latn_to_arab on
#                     the text before tokenizing it
# NOTE: both objects are created eagerly at import time; from_pretrained may
# fetch the checkpoint on first use.
_MMS_UIG_ARABIC_CHECKPOINT = "facebook/mms-tts-uig-script_arabic"

models_info = {
    "Meta-MMS": {
        "processor": AutoTokenizer.from_pretrained(_MMS_UIG_ARABIC_CHECKPOINT),
        "model": VitsModel.from_pretrained(_MMS_UIG_ARABIC_CHECKPOINT),
        "arabic_script": True,
    },
}

# Run inference on a GPU when PyTorch can see one, otherwise on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

def synthesize(text, model_id, output_path="tts_output.wav"):
    """Synthesize speech for *text* and write it to a WAV file.

    Args:
        text: Input text. When the selected model entry has ``arabic_script``
            set, the text is first passed through ``util.ug_latn_to_arab``
            (presumably Uyghur Latin -> Arabic script; confirm against
            ``util``).
        model_id: Key into ``models_info`` (e.g. ``"Meta-MMS"``).
        output_path: Where to write the WAV file. Defaults to
            ``"tts_output.wav"`` in the current working directory. Callers
            that rely on the default overwrite each other's output, so pass a
            distinct path when synthesizing concurrently or keeping results.

    Returns:
        The path the WAV file was written to (i.e. ``output_path``).

    Raises:
        KeyError: If *model_id* is not a key of ``models_info``.
    """
    # Single lookup; raises KeyError for unknown model ids.
    entry = models_info[model_id]
    if entry["arabic_script"]:
        text = util.ug_latn_to_arab(text)

    processor = entry["processor"]
    # nn.Module.to() moves the parameters in place and returns the module,
    # so repeating it on every call is cheap once the model is on `device`.
    model = entry["model"].to(device)
    inputs = processor(text, return_tensors="pt").to(device)

    with torch.no_grad():
        # Bring the waveform back to CPU so it can be converted to NumPy.
        waveform = model(**inputs).waveform.cpu()

    # Write row 0 of the batch; assumes waveform is (batch, samples) with a
    # single input text - TODO confirm.
    scipy.io.wavfile.write(
        output_path,
        rate=model.config.sampling_rate,
        data=waveform.numpy()[0],
    )

    return output_path