import logging

import gradio as gr
import jax.lax
from transformers import AutoTokenizer

import EasyDel
from EasyDel import JAXServer
from EasyDel.transform import llama_from_pretrained
from fjutils import get_float_dtype_by_name

logging.basicConfig(
    level=logging.INFO
)


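# Prompt template for grounded yes/no question answering; it is not used by
# the Gradio handlers below but is kept as a ready-made instruction format.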
instruct = 'Context:\n{context}\nQuestion:\nYes or No question: can you answer ' \
           '"{question}" using only the provided context?'


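# Default Llama-2 chat system prompt, used when no custom system prompt is given.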
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information."


def get_prompt_llama2_format(message: str, chat_history,
                             system_prompt: str) -> str:
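    """Build a Llama-2 chat prompt from a message, chat history, and system prompt.

    Follows the official Llama-2 chat format: the system prompt is wrapped in
    <<SYS>> tags inside the first [INST] block, and each (user, assistant)
    turn is closed with </s><s>[INST].
    """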
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


class InTimeDataFinderJaxServerLlama2Type(JAXServer):
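    """JAXServer subclass serving Llama-2 chat models, with loaders for both
    PyTorch (Hugging Face) checkpoints and pre-converted JAX/Flax checkpoints."""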
    def __init__(self, config=None):
        super().__init__(config=config)

    @classmethod
    def load_from_torch(cls, repo_id, config=None):
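        """Download a PyTorch checkpoint from the Hugging Face Hub, convert it
        to Flax parameters on CPU, and hand the model, params, and tokenizer to
        the JAXServer runtime. The dtype falls back to fp16 when no config is
        given."""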
        with jax.default_device(jax.devices('cpu')[0]):
            param, config_model = llama_from_pretrained(
                repo_id
            )
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load_from_params(
            config_model=config_model,
            model=model,
            config=config,
            params=param,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

    @classmethod
    def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
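        """Download a pre-converted JAX/Flax checkpoint from the Hugging Face
        Hub and load it through JAXServer. `config_repo` allows the model
        config to live in a different repo than the checkpoint itself."""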
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load(
            path=path,
            config_model=config_model,
            model=model,
            config=config,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

    def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
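        """Gradio chat handler: format the conversation in Llama-2 style and
        either emit the full response at once or stream partial responses,
        depending on `config.stream_tokens_for_gradio`."""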
        string = get_prompt_llama2_format(
            message=prompt,
            chat_history=history,
            system_prompt=DEFAULT_SYSTEM_PROMPT
        )
        if not self.config.stream_tokens_for_gradio:
            # Non-streaming path. This must `yield` rather than `return`: the
            # `yield` in the streaming branch makes the whole method a
            # generator, so a plain return would emit nothing to Gradio.
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            history.append([prompt, response])
            yield '', history
        else:
            # Streaming path: update the last history entry on every step.
            history.append([prompt, ''])
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                history[-1][-1] = response
                yield '', history

    def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
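        """Gradio instruct handler: single-turn variant of the chat handler,
        using the supplied system prompt (or the default one) and an empty
        chat history."""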
        # Use the user-supplied system prompt when given; fall back to the
        # default Llama-2 system prompt otherwise.
        string = get_prompt_llama2_format(
            system_prompt=system or DEFAULT_SYSTEM_PROMPT,
            message=prompt,
            chat_history=[]
        )
        if not self.config.stream_tokens_for_gradio:
            # Non-streaming path: yield the final response once (see the note
            # in process_gradio_chat on why this cannot be a plain return).
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            yield '', response
        else:
            # Streaming path: emit the partial response on every step.
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                yield '', response



if __name__ == "__main__":
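    # Demo configuration: 'max_length' is the model context window,
    # 'max_new_tokens' caps generation length, and 'max_stream_tokens' sets
    # how many tokens are produced per streaming step.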
    configs = {
        "repo_id": "meta-llama/Llama-2-7b-chat-hf",
        "max_length": 4096,
        "max_new_tokens": 4096,
        "max_stream_tokens": 64,
        "dtype": 'fp16',
        "use_prefix_tokenizer": True
    }
    for key, value in configs.items():
        print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))
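    # Convert the PyTorch checkpoint to Flax on CPU and launch the local
    # Gradio chat UI (pass share=True for a public link).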
    server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
        repo_id=configs['repo_id'],
        config=configs
    )
    server.gradio_app_chat.launch(share=False)