from __future__ import annotations

from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS

DEFAULT_MODEL = "distilgpt2"


def split_given_size(a, size):
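    """Split ``a`` into consecutive chunks of at most ``size`` elements."""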
    return np.split(a, np.arange(size, len(a), size))


def flip_and_transpose(arr: np.ndarray, flip_first: bool = False):
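    """Flip and transpose a 2-D character array.

    ``flip_first=True`` is the inverse of the ``flip_first=False`` transform.
    """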
    if arr.shape[-1] > 1:
        if flip_first:
            return np.flip(arr, -1).transpose()
        return np.flip(arr.transpose(), -1)
    return arr


def join_list_of_list(str_lists):
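    """Join each inner list of characters into a single string."""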
    return ["".join(s) for s in str_lists]


def characterize(str_lists):
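    """Split each string into a list of its characters."""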
    return [list(s) for s in str_lists]


class MarioDataset(Dataset):
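    """Dataset of Mario levels served as fixed-length token windows.

    The level text is transposed into column-major order, tokenized with a
    tokenizer retrained on the tile characters, and indexed as overlapping
    windows of ``context_len`` tokens (one window per column start unless
    ``sample_all_indices`` is set).
    """
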
    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast] = None,
        level_string: Optional[str] = None,
        context_len: int = 700,
        height: int = 14,
        remove_start_end_tokens: bool = False,
        sample_all_indices: bool = False,
    ):
        if level_string is None:
            print(
                "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
            )
            level_string = FULL_LEVEL_STR_WITH_PATHS
        elif ".txt" in level_string:
            with open(level_string, "r") as file:
                level_string = file.read()

        self.character_set = set(level_string)
        if "\n" in self.character_set:
            self.character_set.remove("\n")
        self.vocab_size = len(self.character_set)
        self.sample_all_indices = sample_all_indices

        def get_training_corpus():
            yield list(level_string)

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)

        self.tokenizer = tokenizer
        # Retrain the tokenizer on the level characters so its vocabulary
        # matches the tile alphabet rather than natural-language subwords.
        if getattr(tokenizer, "train_new_from_iterator", None) is not None:
            # Hugging Face fast tokenizer: retrain it directly. 52000 is only
            # an upper bound; the character corpus yields far fewer tokens.
            self.tokenizer = tokenizer.train_new_from_iterator(
                get_training_corpus(), 52000
            )
        elif getattr(tokenizer, "train_from_iterator", None) is not None:
            # Raw `tokenizers.Tokenizer`: wrap it as a fast tokenizer first,
            # then retrain it on the level characters.
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
            self.tokenizer = self.tokenizer.train_new_from_iterator(
                get_training_corpus(), self.vocab_size
            )
        self.context_len = context_len
        self.height = height

        x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
        self.input_ids = x["input_ids"].squeeze()
        self.attention_masks = x["attention_mask"].squeeze()
        if remove_start_end_tokens:
            self.input_ids = self.input_ids[1:-1]
            self.attention_masks = self.attention_masks[1:-1]

        self.indices = self.generate_indices()

        self.unique_tokens, self.unique_counts = self.input_ids.unique(
            return_counts=True
        )
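        # Inverse-frequency weight for each unique token: rarer tiles receive
        # larger weights.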
        self.weighted_unique_counts = (
            1.0 / self.unique_counts / torch.sum(self.unique_counts)
        )

        # Map each decoded tile character back to its token id (assumes every
        # token in the trained vocabulary decodes to a single character).
        self.token_dict = {}
        string_tokens = list(self.tokenizer.decode(self.unique_tokens))
        for int_token, string_token in zip(self.unique_tokens, string_tokens):
            self.token_dict[string_token] = int_token

    def convert_level_to_tensor(self, level: List[str]):
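        """Flatten level rows into one column-major string and tokenize it."""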
        str_arr = flip_and_transpose(np.array(characterize(level)))
        str_arr = "".join(join_list_of_list(str_arr))

        x = self.tokenizer(str_arr, return_tensors="pt")
        return x, str_arr

    def __len__(self):
        return self.indices.shape[0]

    def __getitem__(self, idx):
        indices = self.indices[idx]
        return self.input_ids[indices], self.attention_masks[indices]

    def generate_indices(self):
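        """Precompute the token-index windows served by ``__getitem__``.

        A ``context_len`` window starts at every column boundary (every
        ``height`` tokens), or at every position if ``sample_all_indices``.
        """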
        out = []
        for idx in range(self.input_ids.shape[0] - self.context_len):
            if idx % self.height == 0 or self.sample_all_indices:
                arange = torch.arange(idx, idx + self.context_len)
                out.append(arange)
        return torch.stack(out)

    def sample_indices(self, batch_size):
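        """Sample ``batch_size`` random ``context_len`` windows of token indices."""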
        out = []
        for _ in range(batch_size):
            # Sample a window start anywhere in the tokenized level, not just
            # at precomputed column boundaries.
            start_idx = np.random.randint(0, self.input_ids.shape[0] - self.context_len)
            indices = torch.arange(start_idx, start_idx + self.context_len)
            out.append(indices)
        return torch.stack(out)

    def __str__(self):
        # Decode one column (``height`` tokens) at a time, then undo the
        # column-major transform from convert_level_to_tensor. Assumes each
        # token decodes to a single tile character and that start/end tokens
        # were removed or never added.
        str_list = characterize(
            self.tokenizer.batch_decode(self.input_ids.reshape(-1, self.height))
        )
        string = "\n".join(
            join_list_of_list(flip_and_transpose(np.array(str_list), True))
        )
        return string

    def generate_mask(self, mask_len: int, batch_size: int = 1):
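        """Build a ``(batch_size, mask_len)`` tensor filled with the <mask> token id."""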
        mask_token = self.tokenizer("<mask>").input_ids[1]
        ones = torch.ones((batch_size, mask_len))
        return ones * mask_token

    def apply_mask(self, level, masked_indices, mask=None):
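        """Return a copy of ``level`` with the tokens at ``masked_indices`` replaced by ``mask``."""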
        if len(level.shape) == 1:
            level = level.unsqueeze(0)
        batch_size = level.shape[0]
        mask_len = masked_indices.shape[-1]
        if mask is None:
            mask = self.generate_mask(mask_len, batch_size)
        mask = mask.long().to(level.device)
        masked_level = level.clone()  # copy so the caller's tensor is unchanged
        masked_level[:, masked_indices] = mask
        return masked_level
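

if __name__ == "__main__":
    # Minimal usage sketch: build the dataset from the bundled
    # FULL_LEVEL_STR_WITH_PATHS levels and iterate one batch with a DataLoader.
    # Assumes the default distilgpt2 tokenizer can be downloaded; batch size
    # and context length are illustrative, not prescribed values.
    from torch.utils.data import DataLoader

    dataset = MarioDataset(context_len=700)
    print(f"{len(dataset)} windows, vocab size {dataset.vocab_size}")

    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    input_ids, attention_masks = next(iter(loader))
    print(input_ids.shape)  # expected: torch.Size([4, 700])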