import os
import re
from typing import Tuple

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Use the last visible CUDA device; torch.cuda.device_count() - 1 is -1 when no
# GPU is available, which the code below interprets as "run on CPU".
device = torch.cuda.device_count() - 1


def get_access_token():
    # Prefer the token in the Streamlit secrets file; fall back to the
    # HF_ACCESS_TOKEN environment variable when no secrets file is present.
    if os.path.exists(".streamlit/secrets.toml"):
        return st.secrets.get("babel")
    return os.environ.get("HF_ACCESS_TOKEN", None)


@st.cache_resource
def load_model(model_name):
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
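    # ul2 checkpoints get the slow tokenizer; every other model uses the fast one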
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=("ul2" not in model_name),
        use_auth_token=get_access_token(),
    )
    if tokenizer.pad_token is None:
        print("Tokenizer has no pad_token; using eos_token as pad_token")
        tokenizer.pad_token = tokenizer.eos_token
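    # Try native PyTorch weights first (framework None); if the repo lacks them,
    # fall back to Flax and finally TensorFlow weights, re-raising on the last.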
    for framework in [None, "flax", "tf"]:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                from_flax=(framework == "flax"),
                from_tf=(framework == "tf"),
                use_auth_token=get_access_token(),
            )
            break
        except EnvironmentError:
            if framework == "tf":
                raise
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
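
# Minimal usage sketch (hedged: the checkpoint name is illustrative, not one
# this app necessarily serves):
#
#     tokenizer, model = load_model("google/flan-t5-small")
#     batch = tokenizer("translate English to German: hello", return_tensors="pt")
#     print(tokenizer.batch_decode(model.generate(**batch), skip_special_tokens=True))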


class Generator:
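    """Wrap a tokenizer/model pair plus generation settings for one text2text task."""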
    def __init__(self, model_name, task, desc, split_sentences):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        print(f"Loading model {self.model_name}")
        self.tokenizer, self.model = load_model(self.model_name)

        for key in self.gen_kwargs:
            if key in self.model.config.__dict__:
                self.gen_kwargs[key] = self.model.config.__dict__[key]
        try:
            if self.task in self.model.config.task_specific_params:
                task_specific_params = self.model.config.task_specific_params[
                    self.task
                ]
                if "prefix" in task_specific_params:
                    self.prefix = task_specific_params["prefix"]
                for key in self.gen_kwargs:
                    if key in task_specific_params:
                        self.gen_kwargs[key] = task_specific_params[key]
        except TypeError:
            # config.task_specific_params is None for models without task presets
            pass

    def generate(self, text: str, streamer=None, **generate_kwargs) -> Tuple[str, dict]:
        # Replace two or more newlines with a single newline in text
        text = re.sub(r"\n{2,}", "\n", text)

        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}

        # If the text contains newlines and the model expects single sentences,
        # generate for each line separately and rejoin with newlines
        if "\n" in text and self.split_sentences:
            lines = text.splitlines()
            translated = [
                self.generate(line, streamer, **generate_kwargs)[0] for line in lines
            ]
            return "\n".join(translated), generate_kwargs

        # If the tokenizer defines a newline_token, encode newlines with it
        if hasattr(self.tokenizer, "newline_token"):
            text = text.replace("\n", self.tokenizer.newline_token)

        # Inputs are not truncated; generate_kwargs["max_length"] bounds the
        # generated output, not the input length, so it is not passed here.
        batch_encoded = self.tokenizer(
            self.prefix + text,
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            streamer=streamer,
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )

        # Special tokens were kept during decoding so that a custom newline token
        # (if the tokenizer defines one) can be restored; strip pad/eos by hand.
        def replace_tokens(pred):
            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            if hasattr(self.tokenizer, "newline_token"):
                pred = pred.replace(self.tokenizer.newline_token, "\n")
            return pred

        decoded_preds = list(map(replace_tokens, decoded_preds))
        return decoded_preds[0], generate_kwargs

    def __str__(self):
        return self.model_name
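
# Minimal usage sketch (hedged: the checkpoint, task, and description below are
# placeholders, not this app's actual configuration):
#
#     gen = Generator("google/flan-t5-small", "translation", "demo", False)
#     text_out, kwargs_used = gen.generate("translate English to German: hello")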


class GeneratorFactory:
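    """Build and index Generator instances from a list of model configurations."""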
    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        # If the generator is not yet present, add it
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            g = Generator(model_name, task, desc, split_sentences)
            self.generators.append(g)

    def get_generator(self, **kwargs):
        for g in self.generators:
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        return [
            g
            for g in self.generators
            if all(g.__dict__.get(k) == v for k, v in kwargs.items())
        ]
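

# Hedged end-to-end sketch: the configuration below is a placeholder, not the
# app's real model list, and st.spinner only renders inside a running
# Streamlit app (elsewhere it degrades to a warning).
if __name__ == "__main__":
    factory = GeneratorFactory(
        [
            {
                "model_name": "google/flan-t5-small",
                "task": "translation",
                "desc": "demo model",
                "split_sentences": True,
            }
        ]
    )
    demo_gen = factory.get_generator(desc="demo model")
    print(demo_gen.generate("translate English to German: good morning")[0])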