import gradio as gr
import time
import requests
import json
import os
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter

API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
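# Both environment variables must be set before launching the app, e.g.
# (illustrative values only, not taken from this repo):
#   export API_URL="http://localhost:8000"
#   export API_KEY="changeme"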

print(f"API_URL: {API_URL}")
print(f"API_KEY: {API_KEY}")

url = f"{API_URL}/v1/chat/completions"

# The headers for the HTTP request
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}


def is_valid_json(data):
    """Return (True, parsed_object) if `data` is valid JSON, else (False, error_message)."""
    try:
        parsed_data = json.loads(data)
        return True, parsed_data
    except ValueError as e:
        return False, str(e)
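
# During streaming, each chunk (after the SSE "data: " prefix is stripped) is
# expected to decode to JSON such as {"choices": [{"delta": {"content": "Hi"}}]}
# (assuming an OpenAI-compatible server); non-JSON lines like the terminating
# "[DONE]" sentinel fail this check and are simply skipped.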


with gr.Blocks() as demo:

    markup = gr.Markdown(
        """
        # Phi-2
        This is a demo of the quantized Phi-2 model in GGUF format (phi-2.Q5_K_M.gguf) hosted on a Kubernetes (K8s) cluster.

        The original model can be found at [MaziyarPanahi/phi-2-GGUF](https://huggingface.co/MaziyarPanahi/phi-2-GGUF)."""
    )
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(lines=1, label="User Message")
    clear = gr.Button("Clear")
    with gr.Row():

        with gr.Column(scale=2):
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Type system prompt here...",
                value="You are a helpful assistant.",
            )
            temperature_input = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            max_new_tokens_input = gr.Slider(
                label="Max New Tokens", minimum=0, maximum=1024, value=256, step=1
            )

        with gr.Column(scale=2):
            top_p_input = gr.Slider(
                label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.01
            )
            top_k_input = gr.Slider(
                label="Top K", minimum=1, maximum=100, value=50, step=1
            )
            repetition_penalty_input = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.01,
            )

    # Note: this helper is defined but not wired to any UI event; the slider
    # and textbox values are passed directly to `bot` via `msg.submit` below.
    def update_globals(
        system_prompt, temperature, max_new_tokens, top_p, top_k, repetition_penalty
    ):
        global global_system_prompt, global_temperature, global_max_new_tokens, global_top_p, global_repetition_penalty, global_top_k
        global_system_prompt = system_prompt
        global_temperature = temperature
        global_max_new_tokens = max_new_tokens
        global_top_p = top_p
        global_top_k = top_k
        global_repetition_penalty = repetition_penalty

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(
        history,
        system_prompt,
        temperature,
        max_new_tokens,
        top_p,
        top_k,
        repetition_penalty,
    ):
        print(f"History in bot: {history}")
        print(f"System Prompt: {system_prompt}")
        print(f"Temperature: {temperature}")
        print(f"Max New Tokens: {max_new_tokens}")
        print(f"Top P: {top_p}")
        print(f"Top K: {top_k}")
        print(f"Repetition Penalty: {repetition_penalty}")

        # Rebuild the full conversation (user and assistant turns) for the API call
        history_messages = []
        for user_msg, assistant_msg in history:
            if user_msg:
                history_messages.append({"content": user_msg, "role": "user"})
            if assistant_msg:
                history_messages.append({"content": assistant_msg, "role": "assistant"})
        history[-1][1] = ""
        sys_msg = [
            {
                "content": (
                    system_prompt if system_prompt else "You are a helpful assistant."
                ),
                "role": "system",
            }
        ]
        history_messages = sys_msg + history_messages
        print(history_messages)

        # Create a session object
        session = requests.Session()

        # Define the retry strategy
        retries = Retry(
            total=5,  # Total number of retries to allow
            backoff_factor=1,  # A backoff factor to apply between attempts
            status_forcelist=[
                500,
                502,
                503,
                504,
            ],  # A set of HTTP status codes that we should force a retry on
            allowed_methods=[
                "HEAD",
                "GET",
                "OPTIONS",
                "POST",
            ],  # HTTP methods to retry on
        )
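        # With backoff_factor=1, urllib3 sleeps an exponentially growing interval
        # between attempts (on the order of 1s, 2s, 4s, ...; the exact schedule
        # depends on the urllib3 version) before giving up after 5 retries.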
        data = {
            "messages": history_messages,
            "stream": True,
            "temprature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "seed": 42,
            "repeat_penalty": repetition_penalty,
            "chat_format": "phind",
            "max_tokens": max_new_tokens,
            # "response_format": {
            #     "type": "json_object",
            # },
        }
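        # For reference, the same request can be reproduced from a shell
        # (illustrative values, not taken from this app's environment):
        #   curl -N "$API_URL/v1/chat/completions" \
        #     -H "Authorization: Bearer $API_KEY" \
        #     -H "Content-Type: application/json" \
        #     -d '{"messages": [...], "stream": true}'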

        # Mount the retry-enabled adapter for both HTTP and HTTPS endpoints
        session.mount("http://", HTTPAdapter(max_retries=retries))
        session.mount("https://", HTTPAdapter(max_retries=retries))

        # Making the POST request with increased timeout and retry logic
        try:
            response = session.post(
                url,
                headers=headers,
                data=json.dumps(data),
                stream=True,
                timeout=(10, 30),
            )
            if response.status_code == 200:
                for line in response.iter_lines():
                    # Filter out keep-alive new lines
                    if line:
                        decoded = line.decode("utf-8")
                        # Strip the SSE "data: " prefix if present
                        if decoded.startswith("data: "):
                            decoded = decoded[len("data: "):]
                        # Check whether the chunk is valid JSON (non-JSON lines
                        # such as the terminating "[DONE]" sentinel are skipped)
                        valid_check = is_valid_json(decoded)
                        if valid_check[0]:
                            try:
                                # The chunk has already been parsed by is_valid_json
                                json_data = valid_check[1]

                                delta_content = (
                                    json_data.get("choices", [{}])[0]
                                    .get("delta", {})
                                    .get("content", "")
                                )

                                if delta_content:  # Ensure there's content to append
                                    history[-1][1] += delta_content
                                    time.sleep(0.05)
                                    yield history
                            except json.JSONDecodeError as e:
                                print(f"Error decoding JSON: {e} data: {decoded}")
        except requests.exceptions.RequestException as e:
            print(f"An error occurred: {e}")

    msg.submit(
        user, [msg, chatbot], [msg, chatbot], queue=True, concurrency_limit=10
    ).then(
        bot,
        inputs=[
            chatbot,
            system_prompt_input,
            temperature_input,
            max_new_tokens_input,
            top_p_input,
            top_k_input,
            repetition_penalty_input,
        ],
        outputs=chatbot,
    )

    clear.click(lambda: None, None, chatbot, queue=False)


demo.queue(default_concurrency_limit=20, max_size=20, api_open=False)
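# Queue settings: default_concurrency_limit caps concurrent runs per event
# listener, max_size bounds how many requests may wait in the queue, and
# api_open=False keeps direct API calls from bypassing the queue.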
if __name__ == "__main__":
    demo.launch(show_api=False, share=False)