# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
import tempfile
import torch
import sys
import gradio as gr
import numpy as np

from huggingface_hub import hf_hub_download

# Setup TTS env
if "vits" not in sys.path:
    sys.path.append("vits")

from vits import commons, utils
from vits.models import SynthesizerTrn


TTS_LANGUAGES = {}
with open(f"data/tts/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        TTS_LANGUAGES[iso.strip()] = name.strip()


class TextMapper(object):
    def __init__(self, vocab_file):
        self.symbols = [
            x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()
        ]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}

    def text_to_sequence(self, text, cleaner_names):
        """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
        Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
        Returns:
        List of integers corresponding to the symbols in the text
        """
        sequence = []
        clean_text = text.strip()
        for symbol in clean_text:
            symbol_id = self._symbol_to_id[symbol]
            sequence += [symbol_id]
        return sequence

    def uromanize(self, text, uroman_pl):
        iso = "xxx"
        with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
            with open(tf.name, "w") as f:
                f.write("\n".join([text]))
            cmd = f"perl " + uroman_pl
            cmd += f" -l {iso} "
            cmd += f" < {tf.name} > {tf2.name}"
            os.system(cmd)
            outtexts = []
            with open(tf2.name) as f:
                for line in f:
                    line = re.sub(r"\s+", " ", line).strip()
                    outtexts.append(line)
            outtext = outtexts[0]
        return outtext

    def get_text(self, text, hps):
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def filter_oov(self, text, lang=None):
        text = self.preprocess_char(text, lang=lang)
        val_chars = self._symbol_to_id
        txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
        return txt_filt

    def preprocess_char(self, text, lang=None):
        """
        Special treatement of characters in certain languages
        """
        if lang == "ron":
            text = text.replace("ț", "ţ")
            print(f"{lang} (ț -> ţ): {text}")
        return text


def synthesize(text=None, lang=None, speed=None):
    if speed is None:
        speed = 1.0

    lang_code = lang.split()[0].strip()

    vocab_file = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="vocab.txt",
        subfolder=f"models/{lang_code}",
    )
    config_file = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="config.json",
        subfolder=f"models/{lang_code}",
    )
    g_pth = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="G_100000.pth",
        subfolder=f"models/{lang_code}",
    )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Run inference with {device}")

    assert os.path.isfile(config_file), f"{config_file} doesn't exist"
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)
    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.to(device)
    _ = net_g.eval()

    _ = utils.load_checkpoint(g_pth, net_g, None)

    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"

    if is_uroman:
        uroman_dir = "uroman"
        assert os.path.exists(uroman_dir)
        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
        text = text_mapper.uromanize(text, uroman_pl)

    text = text.lower()
    text = text_mapper.filter_oov(text, lang=lang)
    stn_tst = text_mapper.get_text(text, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        hyp = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=1.0 / speed,
            )[0][0, 0]
            .cpu()
            .float()
            .numpy()
        )

    hyp = (hyp * 32768).astype(np.int16)
    return (hps.data.sampling_rate, hyp), text


TTS_EXAMPLES = [
    ["I am going to the store.", "eng (English)", 1.0],
    ["안녕하세요.", "kor (Korean)", 1.0],
    ["क्या मुझे पीने का पानी मिल सकता है?", "hin (Hindi)", 1.0],
    ["Tanış olmağıma çox şadam", "azj-script_latin (Azerbaijani, North)", 1.0],
    ["Mu zo murna a cikin ƙasar.", "hau (Hausa)", 1.0],
]