import os
import re

import yaml
import torch
import torchaudio
import numpy as np

import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
from audiosr.utils import (
    default_audioldm_config,
    download_checkpoint,
    read_audio_file,
    lowpass_filtering_prepare_inference,
    wav_feature_extraction,
)
import os


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to `random`, `numpy.random`, and the
            PyTorch CPU and CUDA generators. Its string form is also written
            to the ``PYTHONHASHSEED`` environment variable.

    Side effects:
        Mutates global RNG state, ``os.environ`` and the process-wide
        ``torch.backends.cudnn`` flags.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # hash randomization of the current process is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not just the current one; this is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different algorithms from
    # run to run, which defeats the deterministic flag above; turn it off.
    torch.backends.cudnn.benchmark = False


def text2phoneme(data):
    """Remove ``<...>`` spans from `data` and clean the remainder.

    The stripped string is handed to the phoneme-encoder text module's
    `_clean_text` with the ``english_cleaners2`` cleaner; whatever that
    helper produces is returned unchanged.
    """
    # Non-greedy match so each tag is removed individually.
    without_tags = re.sub(r"<.*?>", "", data)
    cleaner_names = ["english_cleaners2"]
    return text._clean_text(without_tags, cleaner_names)


def text_to_filename(text):
    """Return `text` with spaces, single quotes and double quotes turned into underscores."""
    # One translation pass instead of three chained replacements.
    to_underscore = str.maketrans({" ": "_", "'": "_", '"': "_"})
    return text.translate(to_underscore)


def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    """Compute a normalised 128-bin Kaldi filterbank aligned to `log_mel_spec`.

    Args:
        waveform: Audio tensor accepted by torchaudio's resample / Kaldi
            fbank routines (layout not visible here — TODO confirm).
        sampling_rate: Sample rate of `waveform`; anything other than
            16000 is resampled to 16 kHz first.
        log_mel_spec: Tensor whose first dimension gives the number of
            frames the filterbank is padded or cropped to.

    Returns:
        A dict with the single key ``"ta_kaldi_fbank"`` holding the
        normalised filterbank (presumably of shape
        ``(log_mel_spec.size(0), 128)``).
    """
    mean_offset = -4.2677393
    std_scale = 4.5689974
    target_rate = 16000

    # Bring the audio to 16 kHz only when it is not already there.
    audio = waveform
    if sampling_rate != target_rate:
        audio = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=target_rate
        )

    # Remove the DC offset before feature extraction.
    audio = audio - audio.mean()
    features = torchaudio.compliance.kaldi.fbank(
        audio,
        htk_compat=True,
        sample_frequency=target_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # Match the frame count of the reference spectrogram: zero-pad at the
    # end when too short, truncate when too long.
    wanted_frames = log_mel_spec.size(0)
    shortfall = wanted_frames - features.shape[0]
    if shortfall > 0:
        features = torch.nn.ZeroPad2d((0, 0, 0, shortfall))(features)
    elif shortfall < 0:
        features = features[:wanted_frames, :]

    features = (features - mean_offset) / (std_scale * 2)

    return {"ta_kaldi_fbank": features}


def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
    """Read `input_file` and assemble the model input batch for super resolution.

    Args:
        input_file: Path handed to `read_audio_file`.
        waveform: Not used; it is overwritten by the waveform read from
            `input_file`.
        fbank: Not used by this function.

    Returns:
        A tuple ``(batch, duration)``. `batch` is a dict containing the
        waveform, STFT and log-mel tensors, a fixed ``sampling_rate`` of
        48000, the entries produced by
        `lowpass_filtering_prepare_inference` and ``lowpass_mel``; every
        plain `torch.Tensor` entry gets a leading dimension of size 1.
        `duration` is the value reported by `read_audio_file`.
    """
    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)

    batch = {}
    batch["waveform"] = torch.FloatTensor(waveform)
    batch["stft"] = torch.FloatTensor(stft)
    batch["log_mel_spec"] = torch.FloatTensor(log_mel_spec)
    batch["sampling_rate"] = 48000

    # Merge in whatever the low-pass preparation step produces.
    lowpass_entries = lowpass_filtering_prepare_inference(batch)
    batch.update(lowpass_entries)

    assert "waveform_lowpass" in batch
    # Only the mel output is kept; the second return value is discarded.
    lowpass_mel, _ = wav_feature_extraction(batch["waveform_lowpass"], target_frame)
    batch["lowpass_mel"] = lowpass_mel

    # Add a leading dimension to each entry whose type is exactly torch.Tensor.
    for key, value in batch.items():
        if type(value) is torch.Tensor:
            batch[key] = torch.FloatTensor(value).unsqueeze(0)

    return batch, duration


def round_up_duration(duration):
    """Return a multiple of 2.5 that is one chunk past `duration` rounded to the nearest chunk.

    The chunk count is ``round(duration / 2.5) + 1``, so the result always
    includes one extra 2.5-unit chunk on top of the rounded value.
    """
    chunk = 2.5
    chunk_count = round(duration / chunk) + 1
    return int(chunk_count) * chunk


def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
    """Build a LatentDiffusion model, load its checkpoint and switch it to eval mode.

    Args:
        ckpt_path: Path to a checkpoint file. When None, the checkpoint for
            `model_name` is obtained via `download_checkpoint`.
        config: Path (str) to a YAML configuration file, or None to use
            `default_audioldm_config(model_name)`.
        device: Target device. None or ``"auto"`` picks ``cuda:0`` when CUDA
            is available, otherwise ``mps`` when available, otherwise ``cpu``.
        model_name: Model name passed to the checkpoint and default-config
            helpers.

    Returns:
        The LatentDiffusion instance in eval mode, moved to `device`.
    """
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # Honour an explicitly supplied checkpoint; previously the argument was
    # always overwritten by the downloaded one.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        assert type(config) is str
        # NOTE(review): yaml.load with FullLoader should only be used on
        # trusted configuration files.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # The model constructor reads its device from the config parameters.
    config["model"]["params"]["device"] = device

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)

    # strict=False: keys that do not match between checkpoint and model are
    # tolerated rather than raising.
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion


def super_resolution(
    latent_diffusion,
    input_file,
    seed=42,
    ddim_steps=200,
    guidance_scale=3.5,
    latent_t_per_second=12.8,
    config=None,
):
    """Run super resolution on `input_file` with an already-built model.

    Args:
        latent_diffusion: Model object exposing `generate_batch`.
        input_file: Path of the audio file to process.
        seed: Value converted to int and used to seed all RNGs.
        ddim_steps: Passed through as ``ddim_steps``.
        guidance_scale: Passed through as ``unconditional_guidance_scale``.
        latent_t_per_second: Not used by this function.
        config: Not used by this function.

    Returns:
        Whatever `latent_diffusion.generate_batch` returns for the batch.
    """
    seed_everything(int(seed))

    # No pre-loaded waveform is supplied; the batch is built from the file.
    batch, duration = make_batch_for_super_resolution(input_file, waveform=None)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        return latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            duration=duration,
        )