File size: 2,499 Bytes
bef8623
 
30e5da4
4f70bd6
 
 
1dfec92
bef8623
00a9c71
4f70bd6
bef8623
 
 
 
 
1dfec92
bef8623
4f70bd6
bef8623
4f70bd6
 
 
 
 
 
1417583
 
4f70bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bef8623
 
4f70bd6
 
 
1dfec92
 
bef8623
30e5da4
 
4c14db4
bef8623
e2dd467
bef8623
30e5da4
a651122
e2dd467
30e5da4
4f70bd6
 
 
e2dd467
 
4f70bd6
 
 
 
 
 
8bef169
 
4f70bd6
 
8bef169
00a9c71
4dcae01
4f70bd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile
from parallel_wavegan.utils import load_model
from espnet2.bin.tts_inference import Text2Speech
from turkicTTS_utils import normalization
import util

device = "cuda" if torch.cuda.is_available() else "cpu"

# Registry of selectable TTS back-ends, keyed by the model_id that
# synthesize() receives.
# - Hugging Face entries hold a tokenizer ("processor"), a VITS model ("model")
#   and an "arabic_script" flag. When the flag is set, synthesize() converts
#   the input to Arabic script before tokenizing.
# - "IS2AI-TurkicTTS" maps to None because it has no Hugging Face objects.
#   synthesize() routes it to synthesize_turkic_tts(), which uses the ESPnet
#   pipeline configured below.
# NOTE: from_pretrained() runs when this module is imported, so importing it
# loads (and, if not cached, presumably downloads) the MMS model.
models_info = {
    "Meta-MMS": {
        "processor": AutoTokenizer.from_pretrained("facebook/mms-tts-uig-script_arabic"),
        "model": VitsModel.from_pretrained("facebook/mms-tts-uig-script_arabic"),
        "arabic_script": True
    },
    "IS2AI-TurkicTTS": None
}

# ParallelWaveGAN vocoder that renders the acoustic model's features into a
# waveform in synthesize_turkic_tts(). The checkpoint path is relative to the
# current working directory.
vocoder_checkpoint="parallelwavegan_male2_checkpoint/checkpoint-400000steps.pkl" ### specify vocoder path
vocoder = load_model(vocoder_checkpoint).to(device).eval()
# Strip the weight-norm reparameterization; the vocoder is only used for inference here.
vocoder.remove_weight_norm()

### ESPnet acoustic model (transformer/tacotron2/fastspeech) checkpoint and its training config.
### Both paths are relative to the current working directory.
config_file = "exp/tts_train_raw_char/config.yaml"
model_path = "exp/tts_train_raw_char/train.loss.ave_5best.pth"

# Text-to-feature front-end used by synthesize_turkic_tts(). The decoding
# options below only matter for the model families noted beside them.
text2speech = Text2Speech(
    config_file,
    model_path,
    device=device, ## auto-selected above: "cuda" when available, else "cpu"
    ### only for Tacotron 2
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=True,
    backward_window=1,
    forward_window=3,
    ### only for FastSpeech & FastSpeech2
    speed_control_alpha=1.0,
)
text2speech.spc2wav = None  ### disable griffin-lim: waveforms come from the ParallelWaveGAN vocoder; only 'feat_gen' is read

def synthesize(text, model_id):
    """Synthesize *text* with the back-end registered under *model_id*.

    ``'IS2AI-TurkicTTS'`` is delegated to ``synthesize_turkic_tts``. Every
    other id is looked up in ``models_info`` and run through its Hugging Face
    tokenizer and VITS model.

    Args:
        text: Input text. When the chosen entry's ``arabic_script`` flag is
            set, it is first converted with ``util.ug_latn_to_arab``.
        model_id: A key of ``models_info``. An unknown key raises ``KeyError``.

    Returns:
        Path of the written WAV file, ``"tts_output.wav"``.
    """
    if model_id == 'IS2AI-TurkicTTS':
        return synthesize_turkic_tts(text)

    entry = models_info[model_id]
    if entry["arabic_script"]:
        text = util.ug_latn_to_arab(text)

    tokenizer = entry["processor"]
    vits = entry["model"].to(device)
    encoded = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        # Take the first batch item and bring it to host memory so scipy can write it.
        waveform = vits(**encoded).waveform.cpu().numpy()[0]

    wav_path = "tts_output.wav"
    scipy.io.wavfile.write(wav_path, rate=vits.config.sampling_rate, data=waveform)
    return wav_path

def synthesize_turkic_tts(text, output_path="tts_output.wav", sample_rate=22050):
    """Synthesize Uyghur speech with the IS2AI TurkicTTS (ESPnet) pipeline.

    The input is converted to Latin script and normalized for Uyghur. The
    module-level ``text2speech`` model then generates features ('feat_gen'),
    and the ParallelWaveGAN ``vocoder`` renders them to audio.

    Args:
        text: Uyghur input text, passed through ``util.ug_arab_to_latn``.
        output_path: Destination WAV path. Defaults to the shared
            ``"tts_output.wav"``. Pass a unique path when requests may run
            concurrently, so they do not overwrite each other's output.
        sample_rate: Rate written into the WAV header. It must match the
            vocoder's output rate. The default 22050 Hz is the value this
            script has always used. NOTE(review): confirm it against the
            vocoder checkpoint's config if the checkpoint changes.

    Returns:
        ``output_path``.
    """
    import logging  # local import: the module does not import logging at top level

    text = util.ug_arab_to_latn(text)
    text = normalization(text, 'uyghur')

    with torch.no_grad():
        # Griffin-Lim is disabled on text2speech (spc2wav = None), so only the
        # generated features are used and the vocoder produces the waveform.
        c_mel = text2speech(text)['feat_gen']
        wav = vocoder.inference(c_mel)

    # Flatten to 1-D and move to host memory so scipy can serialize it.
    output = wav.view(-1).cpu().numpy()
    # Previously a bare print(); keep the diagnostic without writing to stdout.
    logging.getLogger(__name__).debug(
        "TurkicTTS generated waveform of shape %s", output.shape
    )

    scipy.io.wavfile.write(output_path, rate=sample_rate, data=output)
    return output_path