import streamlit as st
from transformers import AutoTokenizer
import json
import tempfile
import os
import uuid
import copy
import shutil

# Use Streamlit's wide page layout so the side-by-side st.columns used further
# down get the full browser width.
st.set_page_config(layout="wide")

def sanitize_jinja2(jinja_lines):
    """Collapse a multi-line Jinja template into a single line.

    Each line has its leading space characters and trailing ``"\\n"``
    characters removed, and the pieces are concatenated with no separator.
    Only spaces are stripped from the start (tabs are kept), and only newline
    characters from the end (a trailing carriage return would be kept).

    Args:
        jinja_lines: Iterable of template lines, e.g. the result of
            ``file.readlines()``.

    Returns:
        The concatenated single-line template string ("" for no lines).
    """
    # str.join is linear; repeated ``+=`` on str is only fast as a CPython
    # implementation detail.
    return "".join(line.lstrip(" ").rstrip("\n") for line in jinja_lines)

@st.cache_resource
def get_existing_templates():
    """List the predefined chat templates offered in the UI.

    Returns ``[None]`` followed by the alphabetically sorted names of the
    regular files in ``./templates``. The leading ``None`` is the
    "no template" choice (the selectbox defaults to index 0). Sorting gives
    a stable display order, because ``os.listdir`` order is arbitrary.
    Subdirectories are skipped because each entry is later opened as a file.

    The result is cached with ``st.cache_resource``, so templates added while
    the app is running only show up after the cache is cleared or the app
    restarts.
    """
    template_dir = "./templates"
    template_names = sorted(
        name
        for name in os.listdir(template_dir)
        if os.path.isfile(os.path.join(template_dir, name))
    )
    return [None] + template_names

# NOTE(review): the per-session scratch directories created under ./tmp below are
# never removed, so ./tmp grows with the number of sessions. A periodic cleanup
# (e.g. shutil.rmtree('./tmp') once it holds many entries) may be needed on
# long-running deployments.

# Per-session defaults. A key is only set when it is missing, so values persist
# across Streamlit reruns of this script within the same session.
_SESSION_DEFAULTS = {
    'tokenizer_json': None,  # scratch path for the loaded tokenizer; None until one is loaded
    'tokenizer': None,
    'repo_normalized_name': None,
    'repo_id': None,
    'input_jinja_template': "",
    'successful_template': '',  # last template that rendered without raising
    'generated_prompt_w_add_generation_prompt': '',
    'generated_prompt_wo_add_generation_prompt': '',
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value

# Create the parent scratch directory first; exist_ok makes this safe on every rerun.
os.makedirs("./tmp", exist_ok=True)

# Each session gets its own scratch directory ./tmp/<uuid>.
if 'uuid' not in st.session_state:
    st.session_state['uuid'] = uuid.uuid4()
    os.makedirs(f"./tmp/{st.session_state['uuid']}", exist_ok=True)

title_description = """
Chat Template Generation: Make Chat Easier with Huggingface Tokenizer
"""

st.title(title_description)
st.markdown('This Streamlit app offers an easier way to check a chat template and push it to your own or an existing Hugging Face repo.')

# Entry 0 is the `None` "no template" choice, not a template file, so it is skipped.
list_of_templates = get_existing_templates()
with st.expander("Current predefined templates"):
    for template_name in list_of_templates[1:]:
        st.markdown(f"- {template_name}")
    st.info('More templates will be predefined for easier setup of chat template.', icon="ℹ️")

st.divider()

hf_model_repo_name = st.text_input("Hugging Face Model Repository To Update", value="tiiuae/falcon-7b", max_chars=None, key=None, type="default", 
                        help=None, autocomplete=None, label_visibility="visible")

gen_button = st.button("Get Tokenizer Config")

if gen_button:
    # Pasted repo ids often carry stray whitespace, which would not resolve to a repo.
    requested_repo_id = hf_model_repo_name.strip()
    if not requested_repo_id:
        st.error("Please enter a Hugging Face model repository id.")
    else:
        with st.spinner(text="In progress..."):
            try:
                # Load before touching session state. If loading fails, repo_id must
                # not be switched to a repo whose tokenizer was never loaded; otherwise
                # a later push would send the previous tokenizer to that repo.
                loaded_tokenizer = AutoTokenizer.from_pretrained(requested_repo_id)
            except Exception as err:  # UI boundary: report the failure instead of crashing the page
                st.error(f"Could not load a tokenizer from '{requested_repo_id}': {err}")
            else:
                repo_normalized_name = requested_repo_id.replace("/", "_")
                st.session_state['repo_id'] = requested_repo_id
                st.session_state['tokenizer'] = loaded_tokenizer
                st.session_state['repo_normalized_name'] = repo_normalized_name
                # Use a single path component inside this session's ./tmp/<uuid>
                # directory. The raw repo id contains "/" and would create nested
                # directories that are not fully removed afterwards.
                st.session_state['tokenizer_json'] = f"./tmp/{st.session_state['uuid']}/{repo_normalized_name}"
                # Clear results rendered for a previously loaded repository so they
                # cannot be pushed to this one.
                st.session_state['successful_template'] = ''
                st.session_state['generated_prompt_w_add_generation_prompt'] = ''
                st.session_state['generated_prompt_wo_add_generation_prompt'] = ''
    
def _read_tokenizer_config(tokenizer):
    """Return the parsed ``tokenizer_config.json`` that ``tokenizer`` saves.

    The tokenizer is saved into a fresh temporary directory. The directory is
    deleted when the ``with`` block exits, including when saving or parsing
    raises, so no scratch files are left behind and sessions never share a path.
    """
    with tempfile.TemporaryDirectory() as save_dir:
        tokenizer.save_pretrained(save_dir)
        config_path = os.path.join(save_dir, "tokenizer_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)


def _to_one_liner(template):
    """Collapse ``template`` into one line via ``sanitize_jinja2``.

    Line endings are first normalised the way a text-mode file read would
    normalise them (CRLF and lone CR become LF). The text is then split into
    lines, all in memory and independent of the platform's default encoding.
    """
    normalized = template.replace("\r\n", "\n").replace("\r", "\n")
    return sanitize_jinja2(normalized.split("\n"))


# Everything below needs a tokenizer loaded by "Get Tokenizer Config".
if st.session_state['tokenizer_json'] is not None:
    tokenizer = st.session_state['tokenizer']
    tokenizer_json = _read_tokenizer_config(tokenizer)

    json_spec, chat_col = st.columns(spec=[0.3, 0.7])

    with json_spec:
        st.markdown(f"### Tokenizer Config from {st.session_state['repo_normalized_name']}")
        st.json(json.dumps(tokenizer_json, indent=4))

    with chat_col:
        # Fixed sample conversation used to preview the rendered prompt.
        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
            {"role": "user", "content": "I'd like to show off how chat templating works!"},
        ]
        st.markdown("### Example Conversation")
        st.json(json.dumps(chat, indent=4), expanded=False)

        prompt_template_col, prompt_template_output_col = st.columns(spec=[0.3, 0.7])

        with prompt_template_col:
            selected_template = st.selectbox("Choose Existing Template or Leave Blank. (If template is None, it will check current tokenizer's `chat_template` and `default_chat_template` fields)", 
                options=get_existing_templates(), 
                index=0, placeholder="Choose a template (If template is None, it will check current tokenizer `chat_template` and `default_chat_template` fields)", disabled=False, label_visibility="visible")
            generate_prompt_example_button = st.button("Generate Prompt", key="generate_prompt_example_button")

            if selected_template is not None:
                with open(f"./templates/{selected_template}", "r", encoding="utf-8") as f:
                    st.session_state['input_jinja_template'] = f.read()
            else:
                current_template = tokenizer.chat_template
                if current_template is None:
                    # `default_chat_template` may not exist (newer transformers
                    # releases removed it); getattr avoids an AttributeError.
                    current_template = getattr(tokenizer, "default_chat_template", None)
                # The editor below needs a string; start empty if nothing is set.
                st.session_state['input_jinja_template'] = current_template if current_template is not None else ""

            st.session_state['input_jinja_template'] = st.text_area(
                "Jinja Chat Template", value=st.session_state['input_jinja_template'], 
                height=500, placeholder=None, disabled=False, label_visibility="visible")

        with prompt_template_output_col:
            if generate_prompt_example_button:
                candidate_template = st.session_state['input_jinja_template']
                previous_chat_template = tokenizer.chat_template
                tokenizer.chat_template = _to_one_liner(candidate_template)
                try:
                    prompt_without_generation = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                    prompt_with_generation = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
                except Exception as err:  # UI boundary: user-edited templates may fail to parse or render
                    # Roll back so the downloads and push below never use a template
                    # that failed to render.
                    tokenizer.chat_template = previous_chat_template
                    st.error(f"The chat template could not be rendered: {err}")
                else:
                    st.session_state['generated_prompt_wo_add_generation_prompt'] = prompt_without_generation
                    st.session_state['generated_prompt_w_add_generation_prompt'] = prompt_with_generation
                    st.session_state['successful_template'] = candidate_template

            if st.session_state['successful_template']:
                st.text_area(
                    "Generate Prompt with `add_generation_prompt=False`", value=st.session_state['generated_prompt_wo_add_generation_prompt'], 
                    height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_wo_add_generation_prompt_text_area")

                st.text_area(
                    "Generate Prompt with `add_generation_prompt=True`", value=st.session_state['generated_prompt_w_add_generation_prompt'], 
                    height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_w_add_generation_prompt_text_area")

                access_token_no_cache = st.text_input("HuggingFace Access Token API with Write Access", type="password", key="access_token_no_cache")
                commit_message_text_input = st.text_input("Commit Message", key="commit_message_text_input")
                to_private_checkbox = st.checkbox("To Private Repo", key="to_private_checkbox")
                create_pr_checkbox = st.checkbox("Create PR (Check to contribute to others' model repository 🤗)", key="create_pr_checkbox")
                push_to_hub_button = st.button("Push to Hub", key="push_to_hub_button", use_container_width=True)
                tokenizer_config_content = _read_tokenizer_config(tokenizer)

                st.download_button(
                    label="Download tokenizer_config.json",
                    data=json.dumps(tokenizer_config_content, indent=4),
                    file_name='tokenizer_config.json',
                    mime='application/json',
                    use_container_width=True
                )
                st.download_button(
                    label="Download chat_template.jinja2",
                    data=st.session_state['successful_template'],
                    file_name='chat_template.jinja2',
                    mime='text/plain',
                    use_container_width=True
                )
                if push_to_hub_button:
                    if not access_token_no_cache:
                        # Do not pass an empty token through to push_to_hub.
                        st.error("Please provide a Hugging Face access token with write access.")
                    else:
                        tokenizer.chat_template = _to_one_liner(st.session_state['successful_template'])
                        try:
                            with st.spinner(text="Pushing to hub ..."):
                                tokenizer.push_to_hub(
                                    repo_id=st.session_state['repo_id'],
                                    # None lets transformers use its default commit message.
                                    commit_message=commit_message_text_input or None,
                                    private=to_private_checkbox,
                                    token=access_token_no_cache,
                                    create_pr=create_pr_checkbox)
                        except Exception as e:  # UI boundary: surface hub/auth errors to the user
                            st.error(f"Repo id: {st.session_state['repo_id']}\n\n{e}")
                        else:
                            st.success(f"Pushed the chat template to {st.session_state['repo_id']}.")