# coding=utf-8
from copy import deepcopy
from .get_dataset import CreateDataset
from .logger import LoggerFactory
from .retrieve_dialog import RetrieveDialog
from .utils import load_json, load_txt, save_to_json

import logging
import os

logger = LoggerFactory.create_logger(name="test", level=logging.INFO)


class GetManualTestSamples:
    def __init__(
        self,
        role_name,
        role_data_path,
        save_samples_dir,
        save_samples_path=None,
        prompt_path="dataset_character.txt",
        max_seq_len=4000,
        retrieve_num=20,
    ):
        self.role_name = role_name.strip()
        self.role_data = load_json(role_data_path)
        self.role_info = self.role_data[0]["role_info"].strip()

        self.prompt = load_txt(prompt_path)
        self.prompt = self.prompt.replace("${role_name}", self.role_name)
        self.prompt = self.prompt.replace("${role_info}",
                                          f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        self.retrieve_num = retrieve_num
        self.retrieve = RetrieveDialog(role_name=self.role_name,
                                       raw_dialog_list=[d["dialog"] for d in self.role_data],
                                       retrieve_num=retrieve_num)

        self.max_seq_len = max_seq_len
        if not save_samples_path:
            save_samples_path = f"{self.role_name}.json"
        self.save_samples_path = os.path.join(save_samples_dir, save_samples_path)

    def _add_simi_dialog(self, history: list, content_length):
        retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
        simi_dialogs = deepcopy(retrieve_results)

        if simi_dialogs:
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - content_length,
                                                         train_flag=False)
        logger.debug(f"retrieve_results: {retrieve_results}\nsimi_dialogs: {simi_dialogs}.")
        return simi_dialogs, retrieve_results

    def get_qa_samples_by_file(self,
                               questions_path,
                               user_name="user",
                               keep_retrieve_results_flag=False
                               ):
        questions = load_txt(questions_path).splitlines()
        samples = []
        for question in questions:
            question = question.replace('\\n', "\n")
            query = f"{user_name}:{question}" if ":" not in question else question
            content = self.prompt.replace("${dialog}", query)
            content = content.replace("${user_name}", user_name).strip()

            history = [query]
            simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))

            sample = {
                "role_name": self.role_name,
                "role_info": self.role_info,
                "user_name": user_name,
                "dialog": history,
                "simi_dialogs": simi_dialogs,
            }
            if keep_retrieve_results_flag and retrieve_results:
                sample["retrieve_results"] = retrieve_results
            samples.append(sample)
        self._save_samples(samples)

    def get_qa_samples_by_query(self,
                                questions_query,
                                user_name="user",
                                keep_retrieve_results_flag=False
                                ):
        question = questions_query
        samples = []
        question = question.replace('\\n', "\n")
        query = f"{user_name}: {question}" if ":" not in question else question
        content = self.prompt.replace("${dialog}", query)
        content = content.replace("${user_name}", user_name).strip()

        history = [query]
        simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))

        sample = {
            "role_name": self.role_name,
            "role_info": self.role_info,
            "user_name": user_name,
            "dialog": history,
            "simi_dialogs": simi_dialogs,
        }
        if keep_retrieve_results_flag and retrieve_results:
            sample["retrieve_results"] = retrieve_results
        samples.append(sample)
        self._save_samples(samples)

    def _save_samples(self, samples):
        data = samples
        save_to_json(data, self.save_samples_path)


class CreateTestDataset:
    def __init__(self,
                 role_name,
                 role_samples_path=None,
                 role_data_path=None,
                 prompt_path="dataset_character.txt",
                 max_seq_len=4000):
        self.max_seq_len = max_seq_len
        self.role_name = role_name

        self.prompt = load_txt(prompt_path)
        self.prompt = self.prompt.replace("${role_name}", role_name).strip()

        if not role_data_path:
            print("need role_data_path, check please!")
        self.default_simi_dialogs = None
        if os.path.exists(role_data_path):
            data = load_json(role_data_path)
            role_info = data[0]["role_info"]
        else:
            raise ValueError(f"{self.role_name} didn't find role_info.")
        self.role_info = role_info
        self.prompt = self.prompt.replace("${role_info}", f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        if role_samples_path:
            self.role_samples_path = role_samples_path
        else:
            print("check role_samples_path please!")

    def load_samples(self):
        samples = load_json(self.role_samples_path)
        results = []
        for sample in samples:
            input_text = self.prompt

            simi_dialogs = sample.get("simi_dialogs", None)
            if not simi_dialogs:
                simi_dialogs = self.default_simi_dialogs
            if not simi_dialogs:
                raise ValueError(f"didn't find simi_dialogs.")
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - len(input_text),
                                                         train_flag=False)

            input_text = input_text.replace("${simi_dialog}", simi_dialogs)
            user_name = sample.get("user_name", "user")
            input_text = input_text.replace("${user_name}", user_name)

            dialog = "\n".join(sample["dialog"]) if isinstance(sample["dialog"], list) else sample["dialog"]
            input_text = input_text.replace("${dialog}", dialog)

            assert len(input_text) < self.max_seq_len
            results.append({
                "input_text": input_text,
            })
        return results