import os
import random

import numpy as np
import librosa
import torch

import pytorch_lightning as pl
import soundfile as sf
from torch.nn.utils.rnn import pad_sequence
from transformers import T5Config, T5ForConditionalGeneration

from midi_tokenizer import MidiTokenizer, extrapolate_beat_times
from layer.input import LogMelSpectrogram, ConcatEmbeddingToMel
from preprocess.beat_quantizer import extract_rhythm, interpolate_beat_times
from utils.dsp import get_stereo


DEFAULT_COMPOSERS = {"various composer": 2052}


class TransformerWrapper(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.tokenizer = MidiTokenizer(config.tokenizer)
        self.t5config = T5Config.from_pretrained("t5-small")

        for k, v in config.t5.items():
            self.t5config.__setattr__(k, v)

        self.transformer = T5ForConditionalGeneration(self.t5config)
        self.use_mel = self.config.dataset.use_mel
        self.mel_is_conditioned = self.config.dataset.mel_is_conditioned
        self.composer_to_feature_token = config.composer_to_feature_token

        if self.use_mel and not self.mel_is_conditioned:
            self.composer_to_feature_token = DEFAULT_COMPOSERS

        if self.use_mel:
            self.spectrogram = LogMelSpectrogram()
            if self.mel_is_conditioned:
                n_dim = 512
                composer_n_vocab = len(self.composer_to_feature_token)
                embedding_offset = min(self.composer_to_feature_token.values())
                self.mel_conditioner = ConcatEmbeddingToMel(
                    embedding_offset=embedding_offset,
                    n_vocab=composer_n_vocab,
                    n_dim=n_dim,
                )
        else:
            self.spectrogram = None

        self.lr = config.training.lr

    def forward(self, input_ids, labels):
        """
        Deprecated.
        """
        rt = self.transformer(input_ids=input_ids, labels=labels)
        return rt

    @torch.no_grad()
    def single_inference(
        self,
        feature_tokens=None,
        audio=None,
        beatstep=None,
        max_length=256,
        max_batch_size=64,
        n_bars=None,
        composer_value=None,
    ):
        """
        generate a long audio sequence

        feature_tokens or audio : shape (time, )

        beatstep : shape (time, )
        - input_ids가 해당하는 beatstep 값들
        (offset 빠짐, 즉 beatstep[0] == 0)
        - beatstep[-1] : input_ids가 끝나는 지점의 시간값
        (즉 beatstep[-1] == len(y)//sr)
        """

        assert feature_tokens is not None or audio is not None
        assert beatstep is not None

        if feature_tokens is not None:
            assert len(feature_tokens.shape) == 1

        if audio is not None:
            assert len(audio.shape) == 1

        config = self.config
        PAD = self.t5config.pad_token_id
        n_bars = config.dataset.n_bars if n_bars is None else n_bars

        if beatstep[0] > 0.01:
            print(
                "inference warning : beatstep[0] is not 0 ({beatstep[0]}). all beatstep will be shifted."
            )
            beatstep = beatstep - beatstep[0]

        if self.use_mel:
            input_ids = None
            inputs_embeds, ext_beatstep = self.prepare_inference_mel(
                audio,
                beatstep,
                n_bars=n_bars,
                padding_value=PAD,
                composer_value=composer_value,
            )
            batch_size = inputs_embeds.shape[0]
        else:
            raise NotImplementedError

        # Considering GPU capacity, some sequence would not be generated at once.
        relative_tokens = list()
        for i in range(0, batch_size, max_batch_size):
            start = i
            end = min(batch_size, i + max_batch_size)

            if input_ids is None:
                _input_ids = None
                _inputs_embeds = inputs_embeds[start:end]
            else:
                _input_ids = input_ids[start:end]
                _inputs_embeds = None

            _relative_tokens = self.transformer.generate(
                input_ids=_input_ids,
                inputs_embeds=_inputs_embeds,
                max_length=max_length,
            )
            _relative_tokens = _relative_tokens.cpu().numpy()
            relative_tokens.append(_relative_tokens)

        max_length = max([rt.shape[-1] for rt in relative_tokens])
        for i in range(len(relative_tokens)):
            relative_tokens[i] = np.pad(
                relative_tokens[i],
                [(0, 0), (0, max_length - relative_tokens[i].shape[-1])],
                constant_values=PAD,
            )
        relative_tokens = np.concatenate(relative_tokens)

        pm, notes = self.tokenizer.relative_batch_tokens_to_midi(
            relative_tokens,
            beatstep=ext_beatstep,
            bars_per_batch=n_bars,
            cutoff_time_idx=(n_bars + 1) * 4,
        )

        return relative_tokens, notes, pm

    def prepare_inference_mel(self, audio, beatstep, n_bars, padding_value, composer_value=None):
        n_steps = n_bars * 4
        n_target_step = len(beatstep)
        sample_rate = self.config.dataset.sample_rate
        ext_beatstep = extrapolate_beat_times(beatstep, (n_bars + 1) * 4 + 1)

        def split_audio(audio):
            # Split audio corresponding beat intervals.
            # Each audio's lengths are different.
            # Because each corresponding beat interval times are different.
            batch = []

            for i in range(0, n_target_step, n_steps):

                start_idx = i
                end_idx = min(i + n_steps, n_target_step)

                start_sample = int(ext_beatstep[start_idx] * sample_rate)
                end_sample = int(ext_beatstep[end_idx] * sample_rate)
                feature = audio[start_sample:end_sample]
                batch.append(feature)
            return batch

        def pad_and_stack_batch(batch):
            batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
            return batch

        batch = split_audio(audio)
        batch = pad_and_stack_batch(batch)

        inputs_embeds = self.spectrogram(batch).transpose(-1, -2)
        if self.mel_is_conditioned:
            composer_value = torch.tensor(composer_value).to(self.device)
            composer_value = composer_value.repeat(inputs_embeds.shape[0])
            inputs_embeds = self.mel_conditioner(inputs_embeds, composer_value)
        return inputs_embeds, ext_beatstep

    @torch.no_grad()
    def generate(
        self,
        audio_path=None,
        composer=None,
        model="generated",
        steps_per_beat=2,
        stereo_amp=0.5,
        n_bars=2,
        ignore_duplicate=True,
        show_plot=False,
        save_midi=False,
        save_mix=False,
        midi_path=None,
        mix_path=None,
        click_amp=0.2,
        add_click=False,
        max_batch_size=None,
        beatsteps=None,
        mix_sample_rate=None,
        audio_y=None,
        audio_sr=None,
    ):
        config = self.config
        device = self.device

        if audio_path is not None:
            extension = os.path.splitext(audio_path)[1]
            mix_path = (
                audio_path.replace(extension, f".{model}.{composer}.wav")
                if mix_path is None
                else mix_path
            )
            midi_path = (
                audio_path.replace(extension, f".{model}.{composer}.mid")
                if midi_path is None
                else midi_path
            )

        max_batch_size = 64 // n_bars if max_batch_size is None else max_batch_size
        composer_to_feature_token = self.composer_to_feature_token

        if composer is None:
            composer = random.sample(list(composer_to_feature_token.keys()), 1)[0]

        composer_value = composer_to_feature_token[composer]
        mix_sample_rate = config.dataset.sample_rate if mix_sample_rate is None else mix_sample_rate

        if not ignore_duplicate:
            if os.path.exists(midi_path):
                return

        ESSENTIA_SAMPLERATE = 44100

        if beatsteps is None:
            y, sr = librosa.load(audio_path, sr=ESSENTIA_SAMPLERATE)
            (
                bpm,
                beat_times,
                confidence,
                estimates,
                essentia_beat_intervals,
            ) = extract_rhythm(audio_path, y=y)
            beat_times = np.array(beat_times)
            beatsteps = interpolate_beat_times(beat_times, steps_per_beat, extend=True)
        else:
            y = None

        if self.use_mel:
            if audio_y is None and config.dataset.sample_rate != ESSENTIA_SAMPLERATE:
                if y is not None:
                    y = librosa.core.resample(
                        y,
                        orig_sr=ESSENTIA_SAMPLERATE,
                        target_sr=config.dataset.sample_rate,
                    )
                    sr = config.dataset.sample_rate
                else:
                    y, sr = librosa.load(audio_path, sr=config.dataset.sample_rate)
            elif audio_y is not None:
                if audio_sr != config.dataset.sample_rate:
                    audio_y = librosa.core.resample(
                        audio_y, orig_sr=audio_sr, target_sr=config.dataset.sample_rate
                    )
                    audio_sr = config.dataset.sample_rate
                y = audio_y
                sr = audio_sr

            start_sample = int(beatsteps[0] * sr)
            end_sample = int(beatsteps[-1] * sr)
            _audio = torch.from_numpy(y)[start_sample:end_sample].to(device)
            fzs = None
        else:
            raise NotImplementedError

        relative_tokens, notes, pm = self.single_inference(
            feature_tokens=fzs,
            audio=_audio,
            beatstep=beatsteps - beatsteps[0],
            max_length=config.dataset.target_length * max(1, (n_bars // config.dataset.n_bars)),
            max_batch_size=max_batch_size,
            n_bars=n_bars,
            composer_value=composer_value,
        )

        for n in pm.instruments[0].notes:
            n.start += beatsteps[0]
            n.end += beatsteps[0]

        if show_plot or save_mix:
            if mix_sample_rate != sr:
                y = librosa.core.resample(y, orig_sr=sr, target_sr=mix_sample_rate)
                sr = mix_sample_rate
            if add_click:
                clicks = librosa.clicks(times=beatsteps, sr=sr, length=len(y)) * click_amp
                y = y + clicks
            pm_y = pm.fluidsynth(sr)
            stereo = get_stereo(y, pm_y, pop_scale=stereo_amp)

        if show_plot:
            import note_seq

            note_seq.plot_sequence(note_seq.midi_to_note_sequence(pm))

        if save_mix:
            sf.write(
                file=mix_path,
                data=stereo.T,
                samplerate=sr,
                format="wav",
            )

        if save_midi:
            pm.write(midi_path)

        return pm, composer, mix_path, midi_path