import random
import torch
from slam_llm.utils.model_utils import get_custom_model_factory
from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift
import whisper
import numpy as np
from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
import os
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from typing import Callable


def update_progress(progress_callback: Callable[[str], None] | None, message: str):
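    """Forward a status message to the optional progress callback."""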
    if progress_callback:
        progress_callback(message)


def pull_model_ckpt():
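    """Download the model checkpoint from the Hugging Face Hub unless it is
    already present at CKPT_PATH, using HF_TOKEN from the environment."""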
    os.makedirs(CKPT_LOCAL_DIR, exist_ok=True)
    if os.path.exists(CKPT_PATH):
        return
    hf_hub_download(
        repo_id=CKPT_REPO,
        filename=CKPT_NAME,
        local_dir=CKPT_LOCAL_DIR,
        token=os.getenv("HF_TOKEN"),
    )


# make sure the checkpoint is available as soon as this module is imported
pull_model_ckpt()


def extract_audio_feature(audio_path, mel_size):
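    """Load a wav file, pad or trim it to Whisper's fixed 30 s window, and
    return its log-mel spectrogram of shape (frames, mel_size) together with
    the number of audio positions it will occupy in the model input."""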
    print("Extracting audio features from", audio_path)
    audio_raw = whisper.load_audio(audio_path)
    audio_raw = whisper.pad_or_trim(audio_raw)
    audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
    # the Whisper encoder halves the number of mel frames; the projector is
    # assumed to downsample by a further factor of 5, giving the number of
    # audio positions reserved in the LLM input
    audio_length = (audio_mel.shape[0] + 1) // 2
    audio_length = audio_length // 5

    return audio_mel, audio_length


def get_input_ids(length, special_token_a, special_token_t, vocab_config):
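    """Build one placeholder id sequence per SNAC code layer plus a text-layer
    sequence: an <input> token, `length` pad tokens, an <eoa>/<eot> token, and
    the task-specific special token. Each sequence has a leading batch dim."""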
    input_ids = []
    for i in range(vocab_config.code_layer):
        input_ids_item = []
        input_ids_item.append(layershift(vocab_config.input_a, i))
        input_ids_item += [layershift(vocab_config.pad_a, i)] * length
        input_ids_item += [
            (layershift(vocab_config.eoa, i)),
            layershift(special_token_a, i),
        ]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor(
        [vocab_config.input_t]
        + [vocab_config.pad_t] * length
        + [vocab_config.eot, special_token_t]
    )
    input_ids.append(input_id_T.unsqueeze(0))
    return input_ids


def generate_from_wav(
    wav_path, model, codec_decoder, dataset_config, decode_config, device
):
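    """Run a single speech-to-speech pass: extract Whisper mel features,
    assemble the layered input ids and masks, call model.generate, decode the
    text layer with the tokenizer, and decode the SNAC layers with the codec
    decoder. Returns (audio_waveform_or_None, output_text)."""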
    mel_size = dataset_config.mel_size
    prompt = dataset_config.prompt
    prompt_template = "USER: {}\n ASSISTANT: "
    vocab_config = dataset_config.vocab_config
    special_token_a = vocab_config.answer_a
    special_token_t = vocab_config.answer_t
    code_layer = vocab_config.code_layer
    task_type = dataset_config.task_type

    audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)

    prompt = prompt_template.format(prompt)
    prompt_ids = model.tokenizer.encode(prompt)
    prompt_length = len(prompt_ids)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)

    example_ids = get_input_ids(
        audio_length + prompt_length, special_token_a, special_token_t, vocab_config
    )
    text_layer = example_ids[code_layer]
    text_layer = torch.cat(
        (
            text_layer[:, : audio_length + 1],
            prompt_ids.unsqueeze(0),
            text_layer[:, -2:],
        ),
        dim=1,
    )  # <bos> <audio> <prompt> <eos> <task>
    example_ids[code_layer] = text_layer

    input_length = audio_length
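    # every id in the sequence is >= -1, so this yields an all-True attention mask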
    example_mask = example_ids[0][0].ge(-1)
    example_ids = torch.stack(example_ids).squeeze()

    input_ids = example_ids.unsqueeze(0).to(device)
    attention_mask = example_mask.unsqueeze(0).to(device)
    audio_mel = audio_mel.unsqueeze(0).to(device)
    input_length = torch.tensor([input_length]).to(device)
    audio_length = torch.tensor([audio_length]).to(device)
    task_type = [task_type]

    # mark which positions in the sequence hold audio embeddings: everything
    # after the leading <input_a>/<bos> token, for audio_length positions
    modality_mask = torch.zeros_like(attention_mask)
    padding_left = 1  # +1 for <bos>
    modality_mask[0, padding_left : padding_left + audio_length] = True

    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "audio_mel": audio_mel,
        "input_length": input_length,
        "audio_length": audio_length,
        "modality_mask": modality_mask,
        "task_types": task_type,
    }

    model_outputs = model.generate(**batch, **decode_config)
    text_outputs = model_outputs[7]  # final entry is the text-layer token ids
    audio_outputs = model_outputs[:7]  # first 7 entries are the SNAC codebook layers
    output_text = model.tokenizer.decode(
        text_outputs, add_special_tokens=False, skip_special_tokens=True
    )

    if decode_config.decode_text_only:
        return None, output_text

    # reassemble the 7 SNAC codebook streams and decode them back to a waveform
    audio_tokens = [audio_outputs[layer] for layer in range(7)]
    audiolist = reconscruct_snac(audio_tokens)
    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = codec_decoder.decode(audio)

    return audio_hat, output_text


model = None
codec_decoder = None
device = None


def generate(
    wav_path: str, progress_callback: Callable[[str], None] | None = None
) -> tuple[np.ndarray, int | float]:
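    """Synthesize a spoken response for the audio at `wav_path`.

    The model and codec decoder are loaded lazily on the first call and cached
    in module-level globals. Returns the generated waveform as a NumPy array
    together with its sample rate (24000 Hz).
    """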
    global model, codec_decoder, device

    config = OmegaConf.structured(InferenceConfig())
    train_config, model_config, dataset_config, decode_config = (
        config.train_config,
        config.model_config,
        config.dataset_config,
        config.decode_config,
    )

    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    if model is None or codec_decoder is None or device is None:
        update_progress(progress_callback, "Loading model")
        model_factory = get_custom_model_factory(model_config)
        model, _ = model_factory(train_config, model_config, CKPT_PATH)
        codec_decoder = model.codec_decoder
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

    update_progress(progress_callback, "Generating")
    output_wav, output_text = generate_from_wav(
        wav_path, model, codec_decoder, dataset_config, decode_config, device
    )
    if output_wav is None:
        # decode_text_only was enabled, so no audio tokens were produced
        raise RuntimeError(f"No audio generated; text output was: {output_text!r}")

    # SNAC decodes to a 24 kHz waveform
    return output_wav.squeeze().cpu().numpy(), 24000


if __name__ == "__main__":
    wav_path = "sample.wav"
    generate(wav_path)