from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile
from parallel_wavegan.utils import load_model
from espnet2.bin.tts_inference import Text2Speech
from turkicTTS_utils import normalization
import util

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor (tokenizer) and model for each Hugging Face TTS system
models_info = {
    "IS2AI-TurkicTTS": None,
    "Meta-MMS": {
        "processor": AutoTokenizer.from_pretrained("facebook/mms-tts-uig-script_arabic"),
        "model": VitsModel.from_pretrained("facebook/mms-tts-uig-script_arabic"),
        "arabic_script": True
    },
    "Ixxan-FineTuned-MMS": {
        "processor": AutoTokenizer.from_pretrained("ixxan/mms-tts-uig-script_arabic-UQSpeech"),
        "model": VitsModel.from_pretrained("ixxan/mms-tts-uig-script_arabic-UQSpeech"),
        "arabic_script": True
    }
}

vocoder_checkpoint = "parallelwavegan_male2_checkpoint/checkpoint-400000steps.pkl"  # specify vocoder path
vocoder = load_model(vocoder_checkpoint).to(device).eval()
vocoder.remove_weight_norm()

# Specify the path to the main model (Transformer/Tacotron 2/FastSpeech) and its config file
config_file = "exp/tts_train_raw_char/config.yaml"
model_path = "exp/tts_train_raw_char/train.loss.ave_5best.pth"

text2speech = Text2Speech(
    config_file,
    model_path,
    device=device,  # falls back to CPU when CUDA is unavailable
    # The following options apply only to Tacotron 2
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=True,
    backward_window=1,
    forward_window=3,
    # This option applies only to FastSpeech and FastSpeech 2
    speed_control_alpha=1.0,
)
text2speech.spc2wav = None  # disable Griffin-Lim; the ParallelWaveGAN vocoder generates the waveform instead

def synthesize(text, model_id):
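    """Synthesize `text` with the model selected by `model_id` and return the path to the output WAV file."""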
    print(text)
    # if len(text) > 200:
    #     raise ValueError(f"Input text exceeds 200 characters. Please provide a shorter input text for faster processing.")
    
    if model_id == "IS2AI-TurkicTTS":
        return synthesize_turkic_tts(text)
    
    if models_info[model_id]["arabic_script"]:
        text = util.ug_latn_to_arab(text)
    processor = models_info[model_id]["processor"]
    model = models_info[model_id]["model"].to(device)
    inputs = processor(text, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model(**inputs).waveform.cpu().numpy()[0]  # Move output back to CPU for saving
    
    output_path = "tts_output.wav"
    sample_rate = model.config.sampling_rate
    scipy.io.wavfile.write(output_path, rate=sample_rate, data=output)

    return output_path

def synthesize_turkic_tts(text):
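    """Synthesize `text` with the TurkicTTS ESPnet model and ParallelWaveGAN vocoder; return the output WAV path."""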
    text = util.ug_arab_to_latn(text)  # TurkicTTS expects Latin-script Uyghur input
    text = normalization(text, "uyghur")
    
    with torch.no_grad():
        c_mel = text2speech(text)["feat_gen"]  # generate mel-spectrogram features
        wav = vocoder.inference(c_mel)  # vocode the mel features into a waveform
    
    output = wav.view(-1).cpu().numpy()
    output_path = "tts_output.wav"
    scipy.io.wavfile.write(output_path, rate=22050, data=output)  # TurkicTTS audio is generated at 22.05 kHz

    return output_path
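
# Example usage: a minimal sketch, assuming the input is Uyghur in Latin script
# (the Arabic-script models convert it via `util.ug_latn_to_arab`) and that
# "Meta-MMS" is the desired entry from `models_info` above.
if __name__ == "__main__":
    wav_path = synthesize("yaxshimusiz", "Meta-MMS")
    print(f"Saved synthesized audio to {wav_path}")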