import argparse
import copy
import json
import os
import random
import sys
import time
import typing
import warnings

import openai
import tiktoken
from loguru import logger
from tenacity import Retrying, RetryCallState, _utils, retry_if_not_exception_type
from tenacity.stop import stop_base
from tenacity.wait import wait_base

sys.path.append("..")

from model.crs_model import CRSModel

warnings.filterwarnings("ignore")


def get_exist_dialog_set():
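    """Collect dialog ids that already have a saved result in save_dir (a global set in __main__), so finished dialogs are skipped on restart."""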
    exist_id_set = set()
    for file in os.listdir(save_dir):
        file_id = os.path.splitext(file)[0]
        exist_id_set.add(file_id)
    return exist_id_set


def my_before_sleep(retry_state):
    logger.debug(
        f"Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total"
    )


class my_wait_exponential(wait_base):
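    """Exponential backoff wait, except that attempts which failed with openai.error.Timeout are retried immediately (the caller enlarges request_timeout instead)."""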
    def __init__(
        self,
        multiplier: typing.Union[int, float] = 1,
        max: _utils.time_unit_type = _utils.MAX_WAIT,  # noqa
        exp_base: typing.Union[int, float] = 2,
        min: _utils.time_unit_type = 0,  # noqa
    ) -> None:
        self.multiplier = multiplier
        self.min = _utils.to_seconds(min)
        self.max = _utils.to_seconds(max)
        self.exp_base = exp_base

    def __call__(self, retry_state: RetryCallState) -> float:
        # retry immediately when the attempt failed with an OpenAI timeout
        if retry_state.outcome.failed and isinstance(
            retry_state.outcome.exception(), openai.error.Timeout
        ):
            return 0

        try:
            exp = self.exp_base ** (retry_state.attempt_number - 1)
            result = self.multiplier * exp
        except OverflowError:
            return self.max
        return max(max(0, self.min), min(result, self.max))


class my_stop_after_attempt(stop_base):
    """Stop when the previous attempt >= max_attempt."""

    def __init__(self, max_attempt_number: int) -> None:
        self.max_attempt_number = max_attempt_number

    def __call__(self, retry_state: RetryCallState) -> bool:
        # timed-out attempts do not count towards the attempt limit
        if retry_state.outcome.failed and isinstance(
            retry_state.outcome.exception(), openai.error.Timeout
        ):
            retry_state.attempt_number -= 1
        return retry_state.attempt_number >= self.max_attempt_number


def annotate_completion(prompt, logit_bias=None):
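    """Call the OpenAI completion endpoint with retries; logit_bias can be used to bias generation toward specific tokens (e.g. the score digits)."""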
    if logit_bias is None:
        logit_bias = {}

    request_timeout = 20
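    # start with a short per-request timeout and double it after each failed attempt (capped at 300s below)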
    for attempt in Retrying(
        reraise=True,
        retry=retry_if_not_exception_type(
            (
                openai.error.InvalidRequestError,
                openai.error.AuthenticationError,
            )
        ),
        wait=my_wait_exponential(min=1, max=60),
        stop=my_stop_after_attempt(8),
    ):
        with attempt:
            response = openai.Completion.create(
                model="text-davinci-003",
                prompt=prompt,
                temperature=0,
                max_tokens=128,
                stop="Recommender",
                logit_bias=logit_bias,
                request_timeout=request_timeout,
            )["choices"][0]["text"]
        request_timeout = min(300, request_timeout * 2)

    return response


def get_instruction(dataset):
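    """Return the seeker instructions, the recommendation prompt, the option-letter-to-attribute mapping, and the question templates for the given dataset."""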
    if dataset == "redial_eval":
        item_with_year = True
        init_ask_instruction = """To recommend me items that I will accept, you can choose one of the following options.
A: ask my preference for genre
B: ask my preference for actor
C: ask my preference for director
D: I can directly give recommendations
Please enter the option character. Please respond with only the option character."""
        ask_instruction = """To recommend me items that I will accept, you can choose one of the following options.
A: ask my preference for genre
B: ask my preference for actor
C: ask my preference for director
D: I can directly give recommendations
You have already selected {}; do not repeat them. Please enter the option character."""
        option2attr = {
            "A": "genre",
            "B": "star",
            "C": "director",
            "D": "recommend",
        }
        option2template = {
            "A": "Which genre do you like?",
            "B": "Which star do you like?",
            "C": "Which director do you like?",
        }
    elif dataset == "opendialkg_eval":
        item_with_year = False
        init_ask_instruction = """To recommend me items that I will accept, you can choose one of the following options.
A: ask my preference for genre
B: ask my preference for actor
C: ask my preference for director
D: ask my preference for writer
E: I can directly give recommendations
Please enter the option character. Please respond with only the option character."""
        ask_instruction = """To recommend me items that I will accept, you can choose one of the following options.
A: ask my preference for genre
B: ask my preference for actor
C: ask my preference for director
D: ask my preference for writer
E: I can directly give recommendations
You have already selected {}; do not repeat them. Please enter the option character."""
        option2attr = {
            "A": "genre",
            "B": "actor",
            "C": "director",
            "D": "writer",
            "E": "recommend",
        }
        option2template = {
            "A": "Which genre do you like?",
            "B": "Which actor do you like?",
            "C": "Which director do you like?",
            "D": "Which writer do you like?",
        }
    else:
        raise ValueError(f"unsupported dataset: {dataset}")

    if item_with_year is True:
        rec_instruction = "Please give me 10 recommendations according to my preference (Format: no. title (year if exists). Do not include anything other than the movie list in your response)."
    else:
        rec_instruction = "Please give me 10 recommendations according to my preference (Format: no. title. Do not include anything other than the item list in your response). You can recommend items already mentioned in our dialog."

    return (
        init_ask_instruction,
        ask_instruction,
        rec_instruction,
        option2attr,
        option2template,
    )


def get_model_args(model_name):
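    """Collect the command-line arguments required to instantiate the chosen CRS backbone."""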
    if model_name == "kbrd":
        args_dict = {
            "debug": args.debug,
            "kg_dataset": args.kg_dataset,
            "hidden_size": args.hidden_size,
            "entity_hidden_size": args.entity_hidden_size,
            "num_bases": args.num_bases,
            "rec_model": args.rec_model,
            "conv_model": args.conv_model,
            "context_max_length": args.context_max_length,
            "entity_max_length": args.entity_max_length,
            "tokenizer_path": args.tokenizer_path,
            "encoder_layers": args.encoder_layers,
            "decoder_layers": args.decoder_layers,
            "text_hidden_size": args.text_hidden_size,
            "attn_head": args.attn_head,
            "resp_max_length": args.resp_max_length,
            "seed": args.seed,
        }
    elif model_name == "barcor":
        args_dict = {
            "debug": args.debug,
            "kg_dataset": args.kg_dataset,
            "rec_model": args.rec_model,
            "conv_model": args.conv_model,
            "context_max_length": args.context_max_length,
            "resp_max_length": args.resp_max_length,
            "tokenizer_path": args.tokenizer_path,
            "seed": args.seed,
        }
    elif model_name == "unicrs":
        args_dict = {
            "debug": args.debug,
            "seed": args.seed,
            "kg_dataset": args.kg_dataset,
            "tokenizer_path": args.tokenizer_path,
            "context_max_length": args.context_max_length,
            "entity_max_length": args.entity_max_length,
            "resp_max_length": args.resp_max_length,
            "text_tokenizer_path": args.text_tokenizer_path,
            "rec_model": args.rec_model,
            "conv_model": args.conv_model,
            "model": args.model,
            "num_bases": args.num_bases,
            "text_encoder": args.text_encoder,
        }
    elif model_name == "chatgpt":
        args_dict = {
            "seed": args.seed,
            "debug": args.debug,
            "kg_dataset": args.kg_dataset,
        }

    return args_dict


if __name__ == "__main__":
    local_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

    parser = argparse.ArgumentParser()
    parser.add_argument("--api_key")
    parser.add_argument(
        "--dataset", type=str, choices=["redial_eval", "opendialkg_eval"]
    )
    parser.add_argument("--turn_num", type=int, default=5)
    parser.add_argument(
        "--crs_model",
        type=str,
        choices=["kbrd", "barcor", "unicrs", "chatgpt"],
    )

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--kg_dataset", type=str, choices=["redial", "opendialkg"])

    # model_detailed
    parser.add_argument("--hidden_size", type=int)
    parser.add_argument("--entity_hidden_size", type=int)
    parser.add_argument("--num_bases", type=int, default=8)
    parser.add_argument("--context_max_length", type=int)
    parser.add_argument("--entity_max_length", type=int)

    # model
    parser.add_argument("--rec_model", type=str)
    parser.add_argument("--conv_model", type=str)

    # conv
    parser.add_argument("--tokenizer_path", type=str)
    parser.add_argument("--encoder_layers", type=int)
    parser.add_argument("--decoder_layers", type=int)
    parser.add_argument("--text_hidden_size", type=int)
    parser.add_argument("--attn_head", type=int)
    parser.add_argument("--resp_max_length", type=int)

    # prompt
    parser.add_argument("--model", type=str)
    parser.add_argument("--text_tokenizer_path", type=str)
    parser.add_argument("--text_encoder", type=str)

    args = parser.parse_args()
    openai.api_key = args.api_key
    save_dir = f"../save_{args.turn_num}/ask/{args.crs_model}/{args.dataset}"
    os.makedirs(save_dir, exist_ok=True)

    random.seed(args.seed)

    # template for the recommender's recommendation message
    recommendation_template = "I would recommend the following items:\n\n{}"

    # recommender
    model_args = get_model_args(args.crs_model)
    recommender = CRSModel(crs_model=args.crs_model, **model_args)

    # seeker
    (
        init_ask_instruction,
        ask_instruction,
        rec_instruction,
        option2attr,
        option2template,
    ) = get_instruction(args.dataset)
    options = list(option2attr.keys())

    # scorer
    persuasiveness_template = """Does the explanation make you want to accept the recommendation? Please give your score.
If the explanation mentions one of [{}], give 2.
Else if you think recommended items are worse than [{}], give 0.
Else if you think recommended items are comparable to [{}] according to the explanation, give 1.
Else if you think recommended items are better than [{}] according to the explanation, give 2.
Only answer the score number."""
    encoding = tiktoken.encoding_for_model("text-davinci-003")
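    # bias the tokens "0", "1" and "2" so the persuasiveness scorer strongly prefers answering with a score digit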
    logit_bias = {encoding.encode(str(score))[0]: 10 for score in range(3)}

    with open(f"../data/{args.kg_dataset}/entity2id.json", "r", encoding="utf-8") as f:
        entity2id = json.load(f)
    id2entity = {}
    for k, v in entity2id.items():
        id2entity[int(v)] = k
    entity_list = list(entity2id.keys())

    name2id = {}
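    # map item names to knowledge-graph ids, filled from id2info below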

    with open(f"../data/{args.kg_dataset}/id2info.json", "r", encoding="utf-8") as f:
        id2info = json.load(f)

    for k, v in id2info.items():
        name2id[v["name"]] = k

    dialog_id2data = {}
    with open(
        f"../data/{args.dataset}/test_data_processed.jsonl", encoding="utf-8"
    ) as f:
        lines = f.readlines()
        for line in lines:
            line = json.loads(line)
            dialog_id = str(line["dialog_id"]) + "_" + str(line["turn_id"])
            dialog_id2data[dialog_id] = line

    dialog_id_set = set(dialog_id2data.keys()) - get_exist_dialog_set()
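    # sample the remaining dialogs in random order until all of them have been processed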
    while len(dialog_id_set) > 0:
        print(f"{len(dialog_id_set)} dialogs remaining")
        dialog_id = random.choice(tuple(dialog_id_set))

        data = dialog_id2data[dialog_id]
        conv_dict = copy.deepcopy(data)  # for model
        goal_item_list = [f'"{item}"' for item in conv_dict["rec"]]
        goal_item_str = ", ".join(goal_item_list)
        rec_labels = [name2id[rec] for rec in data["rec"]]

        context_dict = []  # for save
        for i, text in enumerate(conv_dict["context"]):
            if len(text) == 0:
                continue
            if i % 2 == 0:
                role_str = "user"
            else:
                role_str = "assistant"
            context_dict.append({"role": role_str, "content": text})

        # dialog state
        rec_success = False
        asked_options = []
        option2index = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}
        if args.kg_dataset == "redial":
            state = [0, 0, 0, 0]
        elif args.kg_dataset == "opendialkg":
            state = [0, 0, 0, 0, 0]

        for i in range(0, args.turn_num):
            # seeker
            # choose option

            if args.crs_model == "chatgpt":
                conv_dict["context"].append(init_ask_instruction)

            # recommender
            # options (list of str): available options, generate one of them
            gen_inputs, recommender_text = recommender.get_conv(conv_dict)
            if args.crs_model != "chatgpt":
                recommender_choose = recommender.get_choice(gen_inputs, options, state)
            else:
                recommender_choose = recommender.get_choice(
                    gen_inputs, options, state, conv_dict
                )
            selected_option = recommender_choose

            if selected_option == options[-1]:  # choose to rec
                # recommender
                rec_items, rec_truth = recommender.get_rec(conv_dict)
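                # rec_items[0] holds the ranked item ids predicted for this dialog; rec_truth holds the ground-truth label ids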
                rec_pred = rec_items[0]

                rec_items_str = ""
                for j, rec_item in enumerate(rec_pred[:50]):
                    rec_items_str += f"{i + 1}: {id2entity[rec_item]}\n"
                recommender_text = recommendation_template.format(rec_items_str)

                # judge whether success
                for rec_label in rec_truth:
                    if rec_label in rec_pred:
                        rec_success = True
                        break

                context_dict.append(
                    {
                        "role": "assistant",
                        "content": recommender_text,
                        "rec_items": rec_pred,
                        "rec_success": rec_success,
                        "option": selected_option,
                    }
                )
                conv_dict["context"].append(recommender_text)

                # seeker
                if rec_success is True:
                    seeker_text = "That's perfect, thank you!"
                else:
                    seeker_text = "I don't like them."

                context_dict.append({"role": "user", "content": seeker_text})
                conv_dict["context"].append(seeker_text)

            else:  # choose to ask
                recommender_text = option2template[selected_option]
                context_dict.append(
                    {
                        "role": "assistant",
                        "content": recommender_text,
                        "option": selected_option,
                    }
                )
                conv_dict["context"].append(recommender_text)

                # seeker
                ask_attr = option2attr[selected_option]

                # update state
                state[option2index[selected_option]] = -1e5

                ans_attr_list = []
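                # gather the asked attribute's values from the ground-truth items' metadata; they form the seeker's answer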
                for label_id in rec_labels:
                    if str(label_id) in id2info and ask_attr in id2info[str(label_id)]:
                        ans_attr_list.extend(id2info[str(label_id)][ask_attr])
                if len(ans_attr_list) > 0:
                    seeker_text = ", ".join(list(set(ans_attr_list)))
                else:
                    seeker_text = "Sorry, no information about this, please choose another option."

                context_dict.append(
                    {
                        "role": "user",
                        "content": seeker_text,
                        "entity": ans_attr_list,
                    }
                )
                conv_dict["context"].append(seeker_text)
                conv_dict["entity"] += ans_attr_list

            if rec_success is True:
                break

        # score persuasiveness
        # seeker_prompt = ''
        # for turn_dict in context_dict:
        #     if turn_dict['role'] == 'user':
        #         role_str = 'Seeker'
        #     else:
        #         role_str = 'Recommender'
        #     seeker_prompt += f'{role_str}: {turn_dict["content"]}\n'
        # persuasiveness_str = persuasiveness_template.format(goal_item_str, goal_item_str, goal_item_str,
        #                                                          goal_item_str)
        # prompt_str_for_persuasiveness = seeker_prompt + persuasiveness_str
        # prompt_str_for_persuasiveness += '\nSeeker:'

        # persuasiveness_score = annotate_completion(prompt_str_for_persuasiveness, logit_bias).strip()

        # save
        conv_dict["context"] = context_dict
        data["simulator_dialog"] = conv_dict

        with open(f"{save_dir}/{dialog_id}.json", "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        dialog_id_set -= get_exist_dialog_set()