from typing import Any, Callable, List, Tuple

import huggingface_hub

from dataclasses import dataclass

from datetime import datetime

from time import sleep

import inspect

import ast

from random import randint

from urllib.parse import quote

from black import Mode, format_str

import gradio as gr

from huggingface_hub import InferenceClient

from pymongo import MongoClient

from constants import *
from config import DemoConfig
from tools import Tools


@dataclass
class Function:
    """Display metadata for one tool function that the model may call.

    ``description_function`` and ``explanation_function`` turn a call's
    arguments and its result into short, human-readable progress messages.
    """

    # Name of the function; the matching tool is looked up by this name.
    name: str
    # Short label shown for the step before it is executed.
    short_description: str
    # Invoked with the call's keyword arguments; returns the "in progress" text.
    # The accepted arguments differ per function, hence ``Callable[..., str]``.
    description_function: Callable[..., str]
    # Invoked with the call's result; returns the "finished" text.
    explanation_function: Callable[[Any], str]

# Catalogue of the functions exposed to the model. Each ``name`` is looked up on
# the ``Tools`` object. Description lambdas are invoked with the same keyword
# arguments as the tool call, so their parameter names must accept those
# keywords; explanation lambdas receive the tool's return value.
FUNCTIONS = [
    Function(
        name="get_current_location",
        short_description="Finding your city",
        description_function=lambda *_, **__: "Finding your city",
        explanation_function=lambda result: f"Found you in {result}!",
    ),
    Function(
        name="sort_results",
        short_description="Sorting results",
        description_function=lambda places, sort, descending=True, first_n=None: (
            f"Sorting results by {sort} from "
            + ("highest to lowest" if descending else "lowest to highest")
        ),
        explanation_function=lambda result: "Done!",
    ),
    Function(
        name="get_latitude_longitude",
        short_description="Convert to coordinates",
        description_function=lambda location: f"Converting {location} into latitude and longitude coordinates",
        explanation_function=lambda result: "Converted!",
    ),
    Function(
        name="get_distance",
        short_description="Calculate distance",
        description_function=lambda place_1, place_2: "Calculating distances",
        # NOTE(review): assumes the tool returns a sequence whose third item is
        # a printable message - confirm against Tools.get_distance.
        explanation_function=lambda result: result[2],
    ),
    Function(
        name="get_recommendations",
        short_description="Read recommendations",
        # Singular wording only for exactly one topic; an empty list no longer
        # raises IndexError.
        description_function=lambda topics, **__: (
            "Reading recommendations for the following "
            + (
                f"topic: {topics[0]}"
                if len(topics) == 1
                else f"topics: {', '.join(topics)}"
            )
        ),
        explanation_function=lambda result: f"Read {len(result)} recommendations",
    ),
    Function(
        name="find_places_near_location",
        short_description="Look for places",
        description_function=lambda type_of_place, location, radius_miles=50: (
            f"Looking for places near {location} within {radius_miles} miles with the following "
            + (
                f"types: {', '.join(type_of_place)}"
                if isinstance(type_of_place, list)
                else f"type: {type_of_place}"
            )
        ),
        # Zero results must not be reported as "1 place".
        explanation_function=lambda result: (
            "Found 1 place!" if len(result) == 1 else f"Found {len(result)} places!"
        ),
    ),
    Function(
        name="get_some_reviews",
        short_description="Fetching reviews",
        description_function=lambda place_names, **_: "Fetching reviews for the requested items",
        explanation_function=lambda result: f"Fetched {len(result)} reviews!",
    ),
    Function(
        name="out_of_domain",
        short_description="The provided query does not relate to locations, reviews, or recommendations.",
        description_function=lambda user_query: "Irrelevant query detected.",
        explanation_function=lambda result: "Irrelevant query detected.",
    ),
]


class FunctionsHelper:
    # Renders one tool as a Python-style stub (name, signature, docstring) for
    # inclusion in the prompt.
    FUNCTION_DEFINITION_TEMPLATE = '''Function:
def {name}{signature}:
"""
{docstring}
"""

'''
    # Prompt skeleton: the rendered function stubs, a set of worked examples, and
    # finally the user's query. ``{function_definitions}`` and ``{query}`` are
    # ``str.format`` placeholders.
    # NOTE(review): the "What are people saying about Chipotle in Austin?" example
    # calls get_some_reviews with ``topics``/``lat_long``, whereas the last example
    # and the FUNCTIONS description lambda use ``place_names`` - confirm against
    # Tools.get_some_reviews.
    PROMPT_TEMPLATE = \
"""
{function_definitions}

Example:
User Query: I am driving from Austin to San Antonio, passing through San Marcos, give me good food for each stop.
Call: sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="Austin")), sort="rating"); sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="San Marcos")), sort="rating"); sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="San Antonio")), sort="rating")

Example:
User Query: What's the nearest I need to walk to get some food near Stanford University?
Call: sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="Stanford University")), sort="distance")

Example:
User Query: Can you tell me if I should go to San Francisco or San Jose for high end Japanese food?
Call: sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Francisco")), sort="rating"); sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Jose")), sort="rating"); get_some_reviews(place_names=sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Francisco")), sort="rating")); get_some_reviews(place_names=sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Jose")), sort="rating"))

Example:
User Query: What's this bird app stuff?
Call: out_of_domain(user_query="What's this bird app stuff")

Example:
User Query: What is your political affiliation?
Call: out_of_domain(user_query="What is your political affiliation?")

Example:
User Query: What are people saying about Chipotle in Austin?
Call: get_some_reviews(topics=["Chipotle"], lat_long=get_latitude_longitude(location="Austin"))

Example:
User Query: What are people saying about some of the better Chipotles in Austin?
Call: get_some_reviews(place_names=sort_results(places=get_recommendations(topics=["Chipotles"], lat_long=get_latitude_longitude(location="Austin")), sort="rating"))


User Query: {query}<human_end>

Call:"""

    def __init__(self, tools: Tools) -> None:
        """
        Render every entry of ``FUNCTIONS`` into the prompt template.

        Parameters
        ----------
        tools: Tools
            Object exposing one attribute per entry in ``FUNCTIONS``; each
            attribute's signature and docstring are embedded in the prompt.
        """
        self.tools = tools

        rendered_definitions = []
        for spec in FUNCTIONS:
            tool_callable = getattr(tools, spec.name)
            rendered_definitions.append(
                self.FUNCTION_DEFINITION_TEMPLATE.format(
                    name=spec.name,
                    signature=inspect.signature(tool_callable),
                    docstring=inspect.getdoc(tool_callable),
                )
            )

        # The query slot is written back verbatim so it can be filled in later,
        # once per request.
        self.prompt_without_query = self.PROMPT_TEMPLATE.format(
            function_definitions="".join(rendered_definitions),
            query="{query}",
        )

    def get_prompt(self, query: str):
        return self.prompt_without_query.format(query=query)

    def get_function_call_plan(self, function_call_str: str) -> List[str]:
        """
        Dry-run a call string and list the steps it would perform.

        Every known function is replaced by a stub that only records its short
        description, so no tool is actually invoked.

        Parameters
        ----------
        function_call_str: str
            One or more call expressions separated by ``;``.

        Returns
        -------
        function_call_list: List[str]
            Short descriptions in evaluation order (a nested call is recorded
            before the call that consumes its value).
        """
        function_call_list: List[str] = []

        def make_recorder(short_description):
            # A closure binds the description directly; generating stub source
            # text with ``exec`` would break on a quote inside a description.
            def record(**_):
                function_call_list.append(short_description)

            return record

        namespace = {f.name: make_recorder(f.short_description) for f in FUNCTIONS}

        # SECURITY: ``eval`` executes model-generated text. Callers must validate
        # the string beforehand (e.g. by whitelisting the called names).
        for call in function_call_str.split(";"):
            call = call.strip()
            if call:
                eval(call, namespace)
        return function_call_list

    def run_function_call(self, function_call_str: str):
        """
        Execute a call string against the real tools, one top-level call at a time.

        Parameters
        ----------
        function_call_str: str
            One or more call expressions separated by ``;``.

        Yields
        ------
        result
            Return value of the top-level call.
        function_call_list: List[Tuple[str, str]]
            ``(description, explanation)`` for every tool invoked while
            evaluating that call, nested calls first.
        """
        tools = self.tools

        def make_wrapper(function):
            # Wrap one tool so that each invocation is logged together with its
            # human-readable description and explanation.
            def wrapper(**kwargs):
                result = getattr(tools, function.name)(**kwargs)
                function_call_list.append(
                    (
                        function.description_function(**kwargs),
                        function.explanation_function(result),
                    )
                )
                return result

            return wrapper

        namespace = {f.name: make_wrapper(f) for f in FUNCTIONS}

        for call in function_call_str.split(";"):
            call = call.strip()
            if not call:
                continue

            # Fresh log per top-level call. The wrappers read this variable from
            # the enclosing scope at call time, so they see the rebinding.
            function_call_list = []
            # SECURITY: ``eval`` executes model-generated text. Callers must
            # validate the string beforehand (e.g. by whitelisting called names).
            result = eval(call, namespace)
            yield result, function_call_list


class RavenDemo(gr.Blocks):
    def __init__(self, config: DemoConfig) -> None:
        """
        Build the demo UI and create the clients it depends on.

        Parameters
        ----------
        config: DemoConfig
            Supplies the Mongo endpoint and collection name, the Raven endpoint
            and token, and the summary model endpoint; it is also passed to
            ``Tools``.
        """
        theme = gr.themes.Soft(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.blue,
        )
        super().__init__(theme=theme, css=CSS, title="NexusRaven V2 Demo")

        self.config = config
        self.tools = Tools(config)
        self.functions_helper = FunctionsHelper(self.tools)
        # Query logs go to the "logs" entry under ``config.mongo_collection``.
        mongo_client = MongoClient(host=config.mongo_endpoint)
        self.collection = mongo_client[config.mongo_collection]["logs"]

        self.raven_client = InferenceClient(
            model=config.raven_endpoint, token=config.hf_token
        )
        self.summary_model_client = InferenceClient(config.summary_model_endpoint)

        # Number of step textboxes created below; on_submit indexes into them.
        self.max_num_steps = 20
        # Names that generated calls are allowed to use (see
        # whitelist_function_names).
        self.function_call_name_set = set([f.name for f in FUNCTIONS])

        with self:
            # Header row: logo, intro text and one button per example query.
            gr.HTML(HEADER_HTML)
            with gr.Row():
                gr.Image(
                    "NexusRaven.png",
                    show_label=False,
                    show_share_button=True,
                    min_width=200,
                    scale=1,
                )
                with gr.Column(scale=4, min_width=800):
                    gr.Markdown(INTRO_TEXT, elem_classes="inner-large-font")
                    with gr.Row():
                        examples = [
                            gr.Button(query_name) for query_name in EXAMPLE_QUERIES
                        ]

            user_input = gr.Textbox(
                placeholder="Ask anything about places, recommendations, or reviews!",
                show_label=False,
                autofocus=True,
            )
            should_chat = gr.Checkbox(label="Enable Chat Summary", info="If set, summarizes the returned results.", value=True)
            raven_function_call = gr.Code(
                label="🐦‍⬛ NexusRaven V2 13B zero-shot generated function call",
                language="python",
                interactive=False,
                lines=10,
            )
            with gr.Accordion(
                "Executing plan generated by 🐦‍⬛ NexusRaven V2 13B", open=True
            ) as steps_accordion:
                # Hidden placeholders; on_submit makes one visible per step.
                steps = [
                    gr.Textbox(visible=False, show_label=False)
                    for _ in range(self.max_num_steps)
                ]

            with gr.Column():
                # With no results, get_relevant_places falls back to the current
                # location, so the dropdown and map always have one entry.
                initial_relevant_places = self.get_relevant_places([])
                relevant_places = gr.State(initial_relevant_places)
                place_dropdown_choices = self.get_place_dropdown_choices(
                    initial_relevant_places
                )
                places_dropdown = gr.Dropdown(
                    choices=place_dropdown_choices,
                    value=place_dropdown_choices[0],
                    label="Relevant places",
                )
                gmaps_html = gr.HTML(self.get_gmaps_html(initial_relevant_places[0]))

            summary_model_summary = gr.Textbox(
                label="Chat summary",
                interactive=False,
                show_copy_button=True,
                lines=10,
                max_lines=1000,
                autoscroll=False,
                elem_classes="inner-large-font",
            )

            # Read-only view of what is sent to Raven.
            with gr.Accordion("Raven inputs", open=False):
                gr.Textbox(
                    label="Available functions",
                    value="`" + "`, `".join(f.name for f in FUNCTIONS) + "`",
                    interactive=False,
                    show_copy_button=True,
                )
                gr.Textbox(
                    label="Raven prompt",
                    value=self.functions_helper.get_prompt("{query}"),
                    interactive=False,
                    show_copy_button=True,
                    lines=20,
                )

            # on_submit reports failure through has_error; the chained step then
            # turns it into a gr.Error via check_for_error.
            has_error = gr.State(False)
            user_input.submit(
                fn=self.on_submit,
                inputs=[user_input, should_chat],
                # Order must match the tuple built by on_submit's get_returns().
                outputs=[
                    user_input,
                    raven_function_call,
                    summary_model_summary,
                    relevant_places,
                    places_dropdown,
                    gmaps_html,
                    steps_accordion,
                    *steps,
                    has_error,
                ],
                concurrency_limit=20,  # not a hyperparameter
                api_name=False,
            ).then(
                self.check_for_error,
                inputs=has_error,
                outputs=[],
            )

            # Each example button feeds its value through EXAMPLE_QUERIES.get and
            # writes the result into the input box.
            for i, button in enumerate(examples):
                button.click(
                    fn=EXAMPLE_QUERIES.get,
                    inputs=button,
                    outputs=user_input,
                    api_name=f"button_click_{i}",
                )

            # Choosing a place re-renders the embedded map for it.
            places_dropdown.input(
                fn=self.get_gmaps_html_from_dropdown,
                inputs=[places_dropdown, relevant_places],
                outputs=gmaps_html,
            )

    def on_submit(self, query: str, should_chat: bool, request: gr.Request):
        """
        Handle one submitted query, streaming UI updates as a generator.

        Stages: stream a function call from Raven, format and whitelist it,
        show the execution plan, run the calls, refresh the map widgets,
        optionally stream a chat summary, and log the exchange.

        Parameters
        ----------
        query: str
            Text typed by the user.
        should_chat: bool
            Whether to ask the summary model to summarise the results.
        request: gr.Request
            The incoming request; used to determine the client's IP.

        Yields
        ------
        A sequence of output values in this order: user input, function call,
        summary, relevant places, places dropdown, map HTML, steps accordion,
        one value per step textbox, and the error flag.
        """

        def get_returns():
            # Reads the enclosing locals at call time, so every call reflects
            # their latest bindings.
            return (
                user_input,
                raven_function_call,
                summary_model_summary,
                relevant_places,
                places_dropdown,
                gmaps_html,
                steps_accordion,
                *steps,
                has_error,
            )

        def on_error():
            # Fall back to the initial (blank) outputs, re-enable the input box
            # and raise the error flag.
            initial_return[0] = gr.Textbox(interactive=True, autofocus=False)
            initial_return[-1] = True
            return initial_return

        # Lock the input box and blank every output while the request runs.
        user_input = gr.Textbox(interactive=False)
        raven_function_call = ""
        summary_model_summary = ""
        relevant_places = []
        places_dropdown = ""
        gmaps_html = ""
        steps_accordion = gr.Accordion(open=True)
        steps = [gr.Textbox(value="", visible=False) for _ in range(self.max_num_steps)]
        has_error = False
        initial_return = list(get_returns())
        yield initial_return

        # Stage 1: stream the generated function call from Raven.
        raven_prompt = self.functions_helper.get_prompt(
            query.replace("'", r"\'").replace('"', r"\"")
        )
        print(f"{'-' * 80}\nPrompt sent to Raven\n\n{raven_prompt}\n\n{'-' * 80}\n")
        stream = self.raven_client.text_generation(
            raven_prompt, **RAVEN_GENERATION_KWARGS
        )
        for s in stream:
            for c in s:
                raven_function_call += c
                raven_function_call = raven_function_call.removesuffix("Thought:").removesuffix("<bot_end>")
                yield get_returns()

        raw_raven_response = raven_function_call
        print(f"Raw Raven response before formatting: {raw_raven_response}")

        # Stage 2: every ";"-separated call must be valid Python and may only
        # call whitelisted functions.
        r_calls = [c.strip() for c in raven_function_call.split(";") if c.strip()]
        f_r_calls = []
        for r_c in r_calls:
            try:
                f_r_call = format_str(r_c.strip(), mode=Mode())
                is_allowed = self.whitelist_function_names(f_r_call)
            except Exception:
                # Unparseable or otherwise unusable output counts as a failed
                # generation.
                yield on_error()
                return

            if not is_allowed:
                yield on_error()
                return

            f_r_calls.append(f_r_call)

        raven_function_call = "; ".join(f_r_calls)

        yield get_returns()

        # Stage 3: show the plan, one step textbox per function call.
        self._set_client_ip(request)
        try:
            function_call_plan = self.functions_helper.get_function_call_plan(
                raven_function_call
            )
        except Exception as exc:
            # E.g. the generated call used arguments the stubs do not accept.
            print(f"Could not build the function call plan: {exc!r}")
            yield on_error()
            return

        # Only ``max_num_steps`` textboxes exist; extra steps are not displayed.
        for i, v in enumerate(function_call_plan[: len(steps)]):
            steps[i] = gr.Textbox(value=f"{i+1}. {v}", visible=True)
            yield get_returns()
            sleep(0.1)

        # Stage 4: run the calls and stream each step's description/explanation.
        results_gen = self.functions_helper.run_function_call(raven_function_call)
        results = []
        previous_num_calls = 0
        try:
            for result, function_call_list in results_gen:
                # assumes each top-level call returns an iterable of result
                # items - TODO confirm against Tools
                results.extend(result)
                for i, (description, explanation) in enumerate(function_call_list):
                    i = i + previous_num_calls
                    if i >= len(steps):
                        # No textbox left for this step; it ran, but is not shown.
                        continue

                    # Overly long descriptions fall back to the short plan label.
                    if len(description) > 100 and i < len(function_call_plan):
                        description = function_call_plan[i]
                    to_stream = f"{i+1}. {description} ..."
                    steps[i] = ""
                    for c in to_stream:
                        steps[i] += c
                        sleep(0.005)
                        yield get_returns()

                    to_stream = "." * randint(0, 5)
                    for c in to_stream:
                        steps[i] += c
                        sleep(0.2)
                        yield get_returns()

                    to_stream = f" {explanation}"
                    for c in to_stream:
                        steps[i] += c
                        sleep(0.005)
                        yield get_returns()

                previous_num_calls += len(function_call_list)
        except Exception as exc:
            # A failing tool call must not leave the input box locked.
            print(f"Function call execution failed: {exc!r}")
            yield on_error()
            return

        # Stage 5: refresh the places dropdown and the embedded map.
        try:
            relevant_places = self.get_relevant_places(results)
        except Exception:
            relevant_places = self.get_relevant_places([])

        gmaps_html = self.get_gmaps_html(relevant_places[0])
        places_dropdown_choices = self.get_place_dropdown_choices(relevant_places)
        places_dropdown = gr.Dropdown(
            choices=places_dropdown_choices, value=places_dropdown_choices[0]
        )
        steps_accordion = gr.Accordion(open=False)
        yield get_returns()

        # Stage 6: optional chat summary. On a validation error the result list
        # is cut to three quarters of its length and the request is retried.
        while should_chat:
            try:
                summary_model_prompt = self.get_summary_model_prompt(results, query)
                print(
                    f"{'-' * 80}\nPrompt sent to summary model\n\n{summary_model_prompt}\n\n{'-' * 80}\n"
                )
                stream = self.summary_model_client.text_generation(
                    summary_model_prompt, **SUMMARY_MODEL_GENERATION_KWARGS
                )
                for s in stream:
                    s = s.removesuffix("</s>")
                    for c in s:
                        summary_model_summary += c
                        summary_model_summary = (
                            summary_model_summary.lstrip().removesuffix(
                                "</s>"
                            )
                        )
                        yield get_returns()
            # NOTE(review): this is a private huggingface_hub module path -
            # confirm it exists in the pinned version.
            except huggingface_hub.inference._text_generation.ValidationError:
                if len(results) > 1:
                    new_length = (3 * len(results)) // 4
                    results = results[:new_length]
                    continue
                else:
                    break

            break

        # Stage 7: log the exchange, then unlock the input box.
        self.collection.insert_one(
            {
                "query": query,
                "raven_output": raw_raven_response,
                "summary_output": summary_model_summary,
            }
        )

        user_input = gr.Textbox(interactive=True, autofocus=False)
        yield get_returns()

    def check_for_error(self, has_error: bool) -> None:
        """
        Surface a failed submission to the user.

        Raises
        ------
        gr.Error
            Carrying ``ERROR_MESSAGE``, when ``has_error`` is truthy.
        """
        if not has_error:
            return
        raise gr.Error(ERROR_MESSAGE)

    def whitelist_function_names(self, function_call_str: str) -> bool:
        """
        Defensive function name whitelisting inspired by @evan-nexusflow

        Parameters
        ----------
        function_call_str: str
            Python source to inspect.

        Returns
        -------
        is_allowed: bool
            True only if every call in the source targets a plain name that is
            in ``self.function_call_name_set``.

        Raises
        ------
        SyntaxError
            If ``function_call_str`` is not valid Python.
        """
        for node in ast.walk(ast.parse(function_call_str)):
            if not isinstance(node, ast.Call):
                continue

            callee = node.func
            # Calls through attributes, subscripts, lambdas, etc. have no ``id``
            # and can never be whitelisted, so reject them instead of crashing.
            if not isinstance(callee, ast.Name):
                return False
            if callee.id not in self.function_call_name_set:
                return False

        return True

    def get_summary_model_prompt(self, results: List, query: str) -> str:
        """
        Build the prompt sent to the summary model.

        Parameters
        ----------
        results: List
            Items collected from the executed calls. Strings are included
            verbatim; dicts are rendered key by key, restricted to an
            allow-list of keys.
        query: str
            The user's original query.

        Returns
        -------
        prompt: str
            ``SUMMARY_MODEL_PROMPT`` filled with the current location, the
            current time, the rendered results and the query.
        """
        # TODO check what outputs are returned and return them properly
        allowed_keys = {
            "author_name",
            "text",
            "for_location",
            "time",
            "author_url",
            "language",
            "original_language",
            "name",
            "opening_hours",
            "rating",
            "user_ratings_total",
            "vicinity",
            "distance",
            "formatted_address",
            "price_level",
            "types",
        }

        sections = []
        for idx, res in enumerate(results):
            if isinstance(res, str):
                sections.append(f"{res}\n")
                continue

            assert isinstance(res, dict)

            # "formatted_address" -> "Formatted address", and so on.
            item_lines = [
                f"\t{key.replace('_', ' ').capitalize()}: {value}\n"
                for key, value in res.items()
                if key in allowed_keys
            ]
            sections.append(f"Result {idx + 1}\n{''.join(item_lines)}\n")
        results_str = "".join(sections)

        current_time = datetime.now().strftime("%b %d, %Y %H:%M:%S")
        try:
            current_location = self.tools.get_current_location()[0]
        except Exception:
            # Best effort: the summary is still useful without a location.
            current_location = "Current location not found."

        return SUMMARY_MODEL_PROMPT.format(
            current_location=current_location,
            current_time=current_time,
            results=results_str,
            query=query,
        )

    def get_relevant_places(self, results: List) -> List[Tuple[str, str]]:
        """
        Extract the distinct places mentioned in ``results``.

        Results that are not dicts (for example plain strings) are skipped, so
        one such entry no longer discards the places found in the others.

        Parameters
        ----------
        results: List
            Items collected from the executed calls.

        Returns
        -------
        relevant_places: List[Tuple[str, str]]
            A list of tuples, where each tuple is (address, name). When nothing
            usable is found, it holds a single entry for the current location.

        """
        # (address key, name key) combinations, in order of preference.
        key_pairs = (
            ("formatted_address", "name"),
            ("formatted_address", "for_location"),
            ("vicinity", "name"),
        )

        # We use a dict to preserve ordering, while enforcing uniqueness
        relevant_places = dict()
        for result in results:
            if not isinstance(result, dict):
                continue
            for address_key, name_key in key_pairs:
                if address_key in result and name_key in result:
                    relevant_places[(result[address_key], result[name_key])] = None
                    break

        relevant_places = list(relevant_places.keys())

        if not relevant_places:
            current_location = self.tools.get_current_location()[0]
            relevant_places.append((current_location, current_location))

        return relevant_places

    def get_place_dropdown_choices(
        self, relevant_places: List[Tuple[str, str]]
    ) -> List[str]:
        """Return the second element (the name) of each place, in order."""
        names = []
        for place in relevant_places:
            names.append(place[1])
        return names

    def get_gmaps_html(self, relevant_place: Tuple[str, str]) -> str:
        """Fill the map embed template with the URL-quoted address and name."""
        address, name = relevant_place
        template_fields = {"address": quote(address), "location": quote(name)}
        return GMAPS_EMBED_HTML_TEMPLATE.format(**template_fields)

    def get_gmaps_html_from_dropdown(
        self, place_name: str, relevant_places: List[Tuple[str, str]]
    ) -> str:
        """Render the map for the first place whose name equals ``place_name``.

        Raises ``IndexError`` when no place carries that name.
        """
        matches = []
        for place in relevant_places:
            if place[1] == place_name:
                matches.append(place)
        first_match = matches[0]
        return self.get_gmaps_html(first_match)

    def _set_client_ip(self, request: gr.Request) -> None:
        """Store the caller's IP on ``self.tools.client_ip``.

        The first entry of an ``x-forwarded-for`` header wins when one is
        present (looked up in ``request.kwargs["headers"]`` first, then in
        ``request.headers``); otherwise ``request.client.host`` is used.
        """
        header_name = "x-forwarded-for"
        direct_host = request.client.host

        if "headers" in request.kwargs and header_name in request.kwargs["headers"]:
            forwarded = request.kwargs["headers"][header_name]
        else:
            forwarded = request.headers.get(header_name, None)

        if forwarded:
            # The header may list several hops; the first one is the client.
            self.tools.client_ip = forwarded.split(",")[0].strip()
        else:
            self.tools.client_ip = direct_host


# Built at import time so the app object is available as a module attribute.
demo = RavenDemo(DemoConfig.load_from_env())

if __name__ == "__main__":
    # Running the file directly starts the app.
    launch_options = {
        "share": True,
        "allowed_paths": ["logo.png", "NexusRaven.png"],
        "favicon_path": "logo.png",
    }
    demo.launch(**launch_options)