from typing import List, Tuple

import gradio as gr
import note_seq
from matplotlib.figure import Figure
import torch

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from string_to_notes import token_sequence_to_note_sequence
from model import get_model_and_tokenizer


# Use a GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and the tokenizer, and move the model to the selected device
# (a no-op if get_model_and_tokenizer already placed it there).
model, tokenizer = get_model_and_tokenizer()
model.to(device)


def create_seed_string(genre: str = "OTHER") -> str:
    """
    Creates a seed string for generating a new piece.

    Args:
        genre (str, optional): The genre of the piece. Defaults to "OTHER".

    Returns:
        str: The seed string.
    """
    if genre == "RANDOM":
        seed_string = "PIECE_START"
    else:
        seed_string = f"PIECE_START GENRE={genre} TRACK_START"
    return seed_string
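
# For example (hypothetical genre label; the token format follows the
# create_seed_string logic above):
#   create_seed_string("POP")    -> "PIECE_START GENRE=POP TRACK_START"
#   create_seed_string("RANDOM") -> "PIECE_START"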


def get_instruments(text_sequence: str) -> List[str]:
    """
    Extracts the list of instruments from a text sequence.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The list of instruments.
    """
    instruments = []
    parts = text_sequence.split()
    for part in parts:
        if part.startswith("INST="):
            inst = part[len("INST=") :]
            if inst == "DRUMS":
                instruments.append("Drums")
            else:
                # Non-drum tracks encode the General MIDI program number.
                instruments.append(GM_INSTRUMENTS[int(inst)])
    return instruments
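
# For illustration, a hypothetical token stream (assuming GM_INSTRUMENTS
# follows General MIDI program order, where program 0 is the acoustic grand
# piano):
#   get_instruments("PIECE_START TRACK_START INST=0 TRACK_START INST=DRUMS")
#   -> ["Acoustic Grand Piano", "Drums"]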


def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.

    Returns:
        str: The generated instrument sequence.
    """
    seed_length = len(tokenizer.encode(seed))

    # The conditioning tokens and the end-of-track token do not change between
    # attempts, so encode them once, outside the retry loop, and move the
    # input_ids tensor to the same device as the model.
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    # Sample until the model produces at least one new note. Generation is
    # stochastic, so an attempt can reach TRACK_END without emitting any
    # NOTE_ON token beyond the seed; in that case we simply resample.
    while True:
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        generated_sequence = tokenizer.decode(generated_ids[0])

        # Only accept the sample if the newly generated part contains a note.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return generated_sequence
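
# For illustration, a hypothetical call that adds one more track to a piece
# (the GENRE=POP label is assumed to exist in the model's vocabulary):
#   generate_new_instrument(seed="PIECE_START GENRE=POP TRACK_START", temp=0.75)
# The result is the full decoded sequence (seed included), ending at the
# sampled TRACK_END token.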


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[str, str, Figure, str, str]:
    """
    Converts a generated token sequence into the app's output formats: a waveform
    video, a MIDI file, a pianoroll plot, and summary strings.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The tempo in quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[str, str, Figure, str, str]: The waveform video path, MIDI file name,
                                           plot figure, instruments string, and
                                           number-of-tokens string.
    """
    instruments = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize the note sequence to raw audio, then convert it to 16-bit PCM.
    array_of_floats = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    num_tokens = str(len(generated_sequence.split()))
    # gr.make_waveform renders the audio into a waveform video and returns the
    # path to the video file.
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[str, str, Figure, str, str, str]:
    """
    Removes the last instrument track from a song string and returns the various
    output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The tempo in quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[str, str, Figure, str, str, str]: The waveform video path, MIDI file name,
                                                plot figure, instruments string, new song
                                                string, and number-of-tokens string.
    """
    # Split the song into tracks on the 'TRACK_START' marker. The first element
    # is the piece header, so a song with N tracks yields N + 1 parts.
    tracks = text_sequence.split("TRACK_START")
    # Drop the last track and rejoin the rest, reinserting the separators
    # consumed by split().
    modified_tracks = tracks[:-1]
    new_song = "TRACK_START".join(modified_tracks)

    if len(tracks) == 2:
        # Only one instrument was present, so generate a fresh one from the header
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song, qpm=qpm
        )
    elif len(tracks) == 1:
        # No instrument, so start from an empty sequence
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )

    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[str, str, Figure, str, str, str]:
    """
    Regenerates the last instrument track in a song string and returns the various
    output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The tempo in quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[str, str, Figure, str, str, str]: The waveform video path, MIDI file name,
                                                plot figure, instruments string, new song
                                                string, and number-of-tokens string.
    """
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument, so start from an empty sequence
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm
        )
    else:
        # Keep everything up to and including the last INST token, then let the
        # model regenerate the rest of that track.
        next_space_index = text_sequence.find(" ", last_inst_index)
        if next_space_index == -1:
            # The INST token is the final token; keep the whole sequence.
            next_space_index = len(text_sequence)
        new_seed = text_sequence[:next_space_index]
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_seed, qpm=qpm
        )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[str, str, Figure, str, str, str]:
    """
    Re-renders a song string at a new tempo and returns the various output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new tempo in quarter notes per minute.

    Returns:
        Tuple[str, str, Figure, str, str, str]: The waveform video path, MIDI file name,
                                                plot figure, instruments string, unchanged
                                                text sequence, and number-of-tokens string.
    """
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm=qpm
    )
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_song(
    genre: str = "OTHER",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[str, str, Figure, str, str, str]:
    """
    Generates a song given a genre, temperature, initial text sequence, and tempo,
    using the module-level model and tokenizer.

    Args:
        genre (str, optional): The genre of the song. Defaults to "OTHER".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
        qpm (int, optional): The tempo in quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[str, str, Figure, str, str, str]: The waveform video path, MIDI file name,
                                                plot figure, instruments string, generated
                                                song string, and number-of-tokens string.
    """
    if text_sequence == "":
        # No existing material: build a fresh seed from the genre.
        seed_string = create_seed_string(genre)
    else:
        # Continue generation from the provided song string.
        seed_string = text_sequence

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
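

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the Gradio UI. It assumes the
    # checkpoint returned by get_model_and_tokenizer() is available locally and
    # that "POP" is one of the genre labels in the model's vocabulary.
    audio_path, midi_file, fig, instruments_str, song, num_tokens = generate_song(
        genre="POP", temp=0.75, qpm=120
    )
    print(f"Generated {num_tokens} tokens -> {midi_file}")
    print("Instruments:")
    print(instruments_str)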