import json
import logging
import queue
import random
import re
import threading
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from dataclasses import asdict
from typing import Dict, List, Optional

from lagent.actions import ActionExecutor
from lagent.agents import BaseAgent, Internlm2Agent
from lagent.agents.internlm2_agent import Internlm2Protocol
from lagent.schema import AgentReturn, AgentStatusCode, ModelStatusCode
from termcolor import colored

# Configure root logging at INFO level and create a logger named after
# this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)


class SearcherAgent(Internlm2Agent):
    """Answer a single sub-question of the MindSearch planning graph.

    ``template`` is normally a dict with:

    * ``'input'`` -- format string for the query. It receives ``question``,
      ``topic`` (the root question) and ``query`` (an alias of ``question``).
    * ``'context'`` (optional) -- format string applied to each parent
      answer (keys ``question`` and ``answer``). The rendered lines are
      prepended to the query.

    A plain string is accepted and used as the ``'input'`` template, which
    is what makes the ``'{query}'`` default usable.
    """

    def __init__(self, template='{query}', **kwargs) -> None:
        super().__init__(**kwargs)
        # stream_chat indexes self.template['input']; a bare string
        # (including the default) would otherwise fail there with TypeError.
        if isinstance(template, str):
            template = dict(input=template)
        self.template = template

    def stream_chat(self,
                    question: str,
                    root_question: Optional[str] = None,
                    parent_response: Optional[List[dict]] = None,
                    **kwargs) -> AgentReturn:
        """Stream this searcher's answer to ``question``.

        Args:
            question: the sub-question assigned to this node.
            root_question: the user's original question, exposed to the
                ``'input'`` template as ``{topic}``.
            parent_response: answers of parent nodes as
                ``{'question': ..., 'answer': ...}`` dicts. They are only used
                when the template has a ``'context'`` entry.
            **kwargs: forwarded to ``Internlm2Agent.stream_chat``.

        Yields:
            Deep copies of each ``AgentReturn`` produced by the parent agent,
            tagged with ``type='searcher'`` and ``content=question``.
        """
        message = self.template['input'].format(question=question,
                                                topic=root_question,
                                                query=question)
        if parent_response and 'context' in self.template:
            context_lines = [
                self.template['context'].format(**item)
                for item in parent_response
            ]
            message = '\n'.join(context_lines + [message])
        print(colored(f'current query: {message}', 'green'))
        # A fresh random session id is used for every call (presumably to keep
        # concurrent searchers apart on the backend -- TODO confirm).
        for agent_return in super().stream_chat(message,
                                                session_id=random.randint(
                                                    0, 999999),
                                                **kwargs):
            agent_return.type = 'searcher'
            agent_return.content = question
            yield deepcopy(agent_return)


class MindSearchProtocol(Internlm2Protocol):
    """Prompt protocol used by the MindSearch planner.

    It adds ``response_prompt``: the text MindSearchAgent feeds back to the
    planner when the planner's code adds the final response node. It also
    overrides ``format`` to assemble system prompts, few-shot examples and
    the current inner steps into a list of chat messages.
    """

    def __init__(
        self,
        meta_prompt: Optional[str] = None,
        interpreter_prompt: Optional[str] = None,
        plugin_prompt: Optional[str] = None,
        few_shot: Optional[List] = None,
        response_prompt: Optional[str] = None,
        language: Optional[Dict] = None,
        tool: Optional[Dict] = None,
        execute: Optional[Dict] = None,
    ) -> None:
        """
        Args:
            meta_prompt: system prompt placed first in every formatted prompt.
            interpreter_prompt: system prompt emitted with
                ``name='interpreter'``.
            plugin_prompt: system prompt template with a ``{tool_info}``
                placeholder, filled from ``plugin_executor`` in ``format``.
            few_shot: example dialogues inserted before the inner steps.
            response_prompt: text returned to the planner once it adds the
                response node.
            language, tool, execute: role-formatting configs forwarded to
                ``Internlm2Protocol``. ``None`` selects the defaults below.
        """
        # Build the defaults per call: dict literals in the signature would be
        # a single mutable object shared by every protocol instance.
        if language is None:
            language = dict(
                begin='',
                end='',
                belong='assistant',
            )
        if tool is None:
            tool = dict(
                begin='{start_token}{name}\n',
                start_token='<|action_start|>',
                name_map=dict(plugin='<|plugin|>',
                              interpreter='<|interpreter|>'),
                belong='assistant',
                end='<|action_end|>\n',
            )
        if execute is None:
            execute = dict(role='execute',
                           begin='',
                           end='',
                           fallback_role='environment')
        self.response_prompt = response_prompt
        super().__init__(meta_prompt=meta_prompt,
                         interpreter_prompt=interpreter_prompt,
                         plugin_prompt=plugin_prompt,
                         few_shot=few_shot,
                         language=language,
                         tool=tool,
                         execute=execute)

    def format(self,
               inner_step: List[Dict],
               plugin_executor: ActionExecutor = None,
               **kwargs) -> list:
        """Render the prompt as a list of chat-message dicts.

        The order is: meta prompt, plugin prompt, interpreter prompt, few-shot
        examples, then ``inner_step``.

        Args:
            inner_step: the dialogue so far, rendered via ``format_sub_role``.
            plugin_executor: supplies ``get_actions_info()`` for the
                ``{tool_info}`` slot of ``plugin_prompt``.

        Raises:
            ValueError: if ``plugin_prompt`` is configured but no
                ``plugin_executor`` is given.
        """
        formatted = []
        if self.meta_prompt:
            formatted.append(dict(role='system', content=self.meta_prompt))
        if self.plugin_prompt:
            if plugin_executor is None:
                raise ValueError(
                    'plugin_prompt is set but no plugin_executor was given '
                    'to fill its {tool_info} placeholder')
            plugin_prompt = self.plugin_prompt.format(tool_info=json.dumps(
                plugin_executor.get_actions_info(), ensure_ascii=False))
            formatted.append(
                dict(role='system', content=plugin_prompt, name='plugin'))
        if self.interpreter_prompt:
            formatted.append(
                dict(role='system',
                     content=self.interpreter_prompt,
                     name='interpreter'))
        if self.few_shot:
            for few_shot in self.few_shot:
                formatted += self.format_sub_role(few_shot)
        formatted += self.format_sub_role(inner_step)
        return formatted


class WebSearchGraph:
    """Planning graph manipulated by LLM-generated planner code.

    Nodes are sub-questions. Each searcher node is answered on a thread pool
    by a ``SearcherAgent``. Every change is pushed onto
    ``searcher_resp_queue`` as a ``(node_name, node, adjacency)`` tuple so a
    consumer can stream progress. Changes include new root/response nodes,
    new edges, and each streamed searcher output.
    """
    # Sentinel put on ``searcher_resp_queue`` by the code runner once all
    # submitted searches of one code execution have finished.
    end_signal = 'end'
    # Keyword arguments for every SearcherAgent. This is assigned on the class
    # (not per instance), so every graph in the process shares it.
    searcher_cfg = dict()

    def __init__(self):
        self.nodes = {}
        self.adjacency_list = defaultdict(list)
        self.executor = ThreadPoolExecutor(max_workers=10)
        # Future -> "<node_name>-<node_content>" for every submitted search.
        self.future_to_query = dict()
        self.searcher_resp_queue = queue.Queue()

    def add_root_node(self, node_content, node_name='root'):
        """Add the root (user question) node and announce it on the queue."""
        self.nodes[node_name] = dict(content=node_content, type='root')
        self.adjacency_list[node_name] = []
        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))

    def add_node(self, node_name, node_content):
        """Add a searcher node and start answering it in the thread pool.

        Answers of nodes that already have an edge to ``node_name`` and a
        ``response`` at the moment the worker starts are passed to the
        searcher as context.
        """
        self.nodes[node_name] = dict(content=node_content, type='searcher')
        self.adjacency_list[node_name] = []

        def model_stream_thread():
            try:
                agent = SearcherAgent(**self.searcher_cfg)
                # Iterate a snapshot: the planner thread may add nodes/edges
                # (new dict keys) while this worker runs.
                edges = list(self.adjacency_list.items())
                parent_nodes = []
                for start_node, adj in edges:
                    for neighbor in adj:
                        if (neighbor['name'] == node_name
                                and start_node in self.nodes
                                and 'response' in self.nodes[start_node]):
                            parent_nodes.append(self.nodes[start_node])
                parent_response = [
                    dict(question=node['content'], answer=node['response'])
                    for node in parent_nodes
                ]
                answer = None
                for answer in agent.stream_chat(
                        node_content,
                        self.nodes['root']['content'],
                        parent_response=parent_response):
                    self.searcher_resp_queue.put(
                        deepcopy((node_name,
                                  dict(response=answer.response,
                                       detail=answer), [])))
                if answer is None:
                    # Without this guard `answer` would be unbound below.
                    logger.warning(
                        f'Searcher for node {node_name!r} produced no output')
                    return
                self.nodes[node_name]['response'] = answer.response
                self.nodes[node_name]['detail'] = answer
            except Exception as e:
                logger.exception(f'Error in model_stream_thread: {e}')

        self.future_to_query[self.executor.submit(
            model_stream_thread)] = f'{node_name}-{node_content}'

    def add_response_node(self, node_name='response'):
        """Add the terminal response node and announce it on the queue."""
        self.nodes[node_name] = dict(type='end')
        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))

    def add_edge(self, start_node, end_node):
        """Link ``start_node`` -> ``end_node`` and announce the updated edges.

        New edges start with ``state=2`` (not started).
        """
        self.adjacency_list[start_node].append(
            dict(id=str(uuid.uuid4()), name=end_node, state=2))
        self.searcher_resp_queue.put((start_node, self.nodes[start_node],
                                      self.adjacency_list[start_node]))

    def reset(self):
        """Drop all nodes and edges (queued events and futures are kept)."""
        self.nodes = {}
        self.adjacency_list = defaultdict(list)

    def node(self, node_name):
        """Return a shallow copy of the node dict named ``node_name``."""
        return self.nodes[node_name].copy()


class MindSearchAgent(BaseAgent):
    """Planner agent that answers a question by building a WebSearchGraph.

    On each turn the LLM emits either plain language, which is treated as the
    final answer, or Python code that manipulates a ``graph`` object
    (``WebSearchGraph``). The code runs in a background thread. The searcher
    nodes it creates are answered concurrently, and their streamed outputs
    are relayed to the caller. Answers of the nodes the code reads via
    ``graph.node("...")`` are fed back to the LLM, with ``[[n]]`` citation
    ids shifted so they stay unique across turns.
    """

    def __init__(self,
                 llm,
                 searcher_cfg,
                 protocol=None,
                 max_turn=10):
        """
        Args:
            llm: chat model whose ``stream_chat(prompt, session_id=..., **kw)``
                yields ``(model_state, response, _)`` triples.
            searcher_cfg (dict): keyword arguments for every SearcherAgent.
                NOTE: stored on the ``WebSearchGraph`` class, so it is shared
                by all MindSearchAgent instances in the process.
            protocol: prompt protocol. A fresh ``MindSearchProtocol()`` is
                used when omitted.
            max_turn (int): maximum planner turns per ``stream_chat`` call.
        """
        if protocol is None:
            # Built per instance: a default evaluated in the signature would be
            # one protocol object shared by every agent.
            protocol = MindSearchProtocol()
        self.local_dict = {}  # namespace in which planner code is exec'd
        self.ptr = 0  # offset added to citation ids to keep them unique
        self.llm = llm
        self.max_turn = max_turn
        WebSearchGraph.searcher_cfg = searcher_cfg
        super().__init__(llm=llm, action_executor=None, protocol=protocol)

    def stream_chat(self, message, **kwargs):
        """Run the planner loop, streaming intermediate results.

        Args:
            message: a string, a single message dict, or a list of message
                dicts.
            as_dict (bool, keyword): convert searcher ``detail`` objects to
                dicts via ``dataclasses.asdict``.
            return_early (bool, keyword): forwarded to ``execute_code``.
            **kwargs: everything else is passed to ``llm.stream_chat``.

        Yields:
            ``AgentReturn`` snapshots for planner output, and
            ``(AgentReturn, node_name)`` tuples for searcher-node updates.
        """
        if isinstance(message, str):
            message = [{'role': 'user', 'content': message}]
        elif isinstance(message, dict):
            message = [message]
        as_dict = kwargs.pop('as_dict', False)
        return_early = kwargs.pop('return_early', False)
        self.local_dict.clear()
        self.ptr = 0
        inner_history = message[:]
        agent_return = AgentReturn()
        agent_return.type = 'planner'
        agent_return.nodes = {}
        agent_return.adjacency_list = {}
        agent_return.inner_steps = deepcopy(inner_history)
        for _ in range(self.max_turn):
            prompt = self._protocol.format(inner_step=inner_history)
            code = None
            # Pre-bind so an LLM stream that yields nothing does not leave
            # these names unbound after the loop.
            response, language = '', ''
            for model_state, response, _ in self.llm.stream_chat(
                    prompt, session_id=random.randint(0, 999999), **kwargs):
                if model_state.value < 0:
                    agent_return.state = getattr(AgentStatusCode,
                                                 model_state.name)
                    yield deepcopy(agent_return)
                    return
                response = response.replace('<|plugin|>', '<|interpreter|>')
                _, language, action = self._protocol.parse(response)
                if not language and not action:
                    continue
                code = action['parameters']['command'] if action else ''
                agent_return.state = self._determine_agent_state(
                    model_state, code, agent_return)
                agent_return.response = language if not code else code
                yield deepcopy(agent_return)

            inner_history.append({'role': 'language', 'content': language})
            print(colored(response, 'blue'))

            if not code:
                # Plain language with no code is the final answer.
                agent_return.state = AgentStatusCode.END
                yield deepcopy(agent_return)
                return
            yield from self._process_code(agent_return, inner_history, code,
                                          as_dict, return_early)

        agent_return.state = AgentStatusCode.END
        yield deepcopy(agent_return)

    def _determine_agent_state(self, model_state, code, agent_return):
        """Pick the AgentStatusCode for a streamed planner chunk.

        ``model_state`` is accepted for interface compatibility but does not
        affect the result.
        """
        if code:
            # NOTE(review): this reports PLUGIN_START even once the model
            # stream has ended (ModelStatusCode.END); PLUGIN_END may have been
            # intended -- confirm with downstream consumers before changing.
            return AgentStatusCode.PLUGIN_START
        if agent_return.nodes and 'response' in agent_return.nodes:
            return AgentStatusCode.ANSWER_ING
        return AgentStatusCode.STREAM_ING

    @staticmethod
    def _refresh_edge_states(agent_return):
        """Set ``state`` on every edge: 1 running, 2 not started, 3 finished."""
        for neighbors in agent_return.adjacency_list.values():
            for neighbor in neighbors:
                target = agent_return.nodes.get(neighbor['name'])
                if target is None or 'detail' not in target:
                    neighbor['state'] = 2
                    continue
                detail = target['detail']
                # With as_dict=True the detail was converted by asdict().
                status = (detail['state']
                          if isinstance(detail, dict) else detail.state)
                neighbor['state'] = 3 if status == AgentStatusCode.END else 1

    def _process_code(self,
                      agent_return,
                      inner_history,
                      code,
                      as_dict=False,
                      return_early=False):
        """Execute planner ``code`` and stream the resulting graph updates.

        Yields ``(agent_return, node_name)`` after each node update. Edge
        updates are recorded but not yielded on their own. It then appends
        the tool call and its reference text to ``inner_history`` and yields
        a final PLUGIN_RETURN snapshot.
        """
        for node_name, node, adj in self.execute_code(
                code, return_early=return_early):
            if as_dict and 'detail' in node:
                node['detail'] = asdict(node['detail'])
            if not adj:
                agent_return.nodes[node_name] = node
            else:
                agent_return.adjacency_list[node_name] = adj
            self._refresh_edge_states(agent_return)
            if not adj:
                yield deepcopy((agent_return, node_name))
        reference, references_url = self._generate_reference(
            agent_return, code, as_dict)
        inner_history.append({
            'role': 'tool',
            'content': code,
            'name': 'plugin'
        })
        inner_history.append({
            'role': 'environment',
            'content': reference,
            'name': 'plugin'
        })
        agent_return.inner_steps = deepcopy(inner_history)
        agent_return.state = AgentStatusCode.PLUGIN_RETURN
        agent_return.references.update(references_url)
        yield deepcopy(agent_return)

    @staticmethod
    def _extract_ref2url(detail, as_dict):
        """Map citation id (str) -> url from a searcher's first action result.

        Returns None when there is no action, no result, or the result
        content is not a JSON object.
        """
        actions = detail['actions'] if as_dict else detail.actions
        if not actions:
            return None
        result = actions[0]['result'] if as_dict else actions[0].result
        content = result[0]['content'] if result else None
        if not content:
            return None
        try:
            ref_results = json.loads(content)
        except (TypeError, ValueError):
            logger.warning('Search result is not valid JSON; URLs ignored')
            return None
        if not isinstance(ref_results, dict):
            return None
        return {idx: item['url'] for idx, item in ref_results.items()}

    def _generate_reference(self, agent_return, code, as_dict):
        """Build the tool-return text fed back to the planner after ``code``.

        If ``code`` adds the response node, the protocol's
        ``response_prompt`` is returned with no URLs. Otherwise, each node
        read via ``graph.node("...")`` is emitted as a ``## <name>`` section.
        Its ``[[n]]`` citations are shifted by ``self.ptr``, and ``self.ptr``
        is then advanced past the largest id seen.

        Returns:
            tuple: (reference text, {shifted citation id (str): url}).
        """
        if 'add_response_node' in code:
            return self._protocol.response_prompt, dict()
        node_names = [
            name.strip().strip('\"') for name in re.findall(
                r'graph\.node\("((?:[^"\\]|\\.)*?)"\)', code)
        ]
        references = []
        references_url = dict()
        for node_name in node_names:
            node = agent_return.nodes.get(node_name)
            if not node or 'detail' not in node or 'response' not in node:
                # e.g. the searcher thread failed; skip instead of KeyError.
                logger.warning(
                    f'No finished answer for node {node_name!r}; skipped')
                continue
            ref2url = self._extract_ref2url(node['detail'], as_dict)

            ref = f"## {node_name}\n\n{node['response']}\n"
            updated_ref = re.sub(
                r'\[\[(\d+)\]\]',
                lambda match: f'[[{int(match.group(1)) + self.ptr}]]', ref)
            numbers = [int(n) for n in re.findall(r'\[\[(\d+)\]\]', ref)]
            if numbers:
                missing = [
                    n for n in numbers if not ref2url or str(n) not in ref2url
                ]
                if missing:
                    logger.info(f'Illegal reference id: {missing}')
                if ref2url:
                    references_url.update({
                        str(idx + self.ptr): ref2url[str(idx)]
                        for idx in set(numbers) if str(idx) in ref2url
                    })
                self.ptr += max(numbers) + 1
            references.append(updated_ref)
        return '\n'.join(references), references_url

    def execute_code(self, command: str, return_early=False):
        """Run planner code in a background thread and yield graph updates.

        Yields deep-copied ``(node_name, node, adjacency)`` tuples. ``root``
        and ``response`` updates are yielded immediately. Searcher-node
        updates are buffered so that nodes are reported one at a time, in the
        order they first appeared. With ``return_early``, once the active
        node's END update is buffered it is yielded directly and the node's
        remaining buffered updates are dropped.
        """

        def extract_code(text: str) -> str:
            text = re.sub(r'from ([\w.]+) import WebSearchGraph', '', text)
            triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
            single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
            if triple_match:
                return triple_match.group(1)
            elif single_match:
                return single_match.group(1)
            return text

        def run_command(cmd):
            # SECURITY: `cmd` is LLM-generated and exec'd with this module's
            # globals, i.e. with full interpreter privileges. Only run it
            # with trusted models or inside a sandbox.
            try:
                exec(cmd, globals(), self.local_dict)
                if self.local_dict.get('graph') is None:
                    raise RuntimeError('planner code did not define `graph`')
            except Exception as e:
                logger.exception(f'Error executing code: {e}')
                raise
            finally:
                plan_graph = self.local_dict.get('graph')
                if plan_graph is not None:
                    # Wait for every search submitted so far, then signal the
                    # consumer even if the code failed midway, so it flushes
                    # buffered output instead of stalling on the queue.
                    for future in as_completed(plan_graph.future_to_query):
                        try:
                            future.result()
                        except Exception:
                            logger.exception(
                                'Searcher failed: '
                                f'{plan_graph.future_to_query.get(future)}')
                    plan_graph.future_to_query.clear()
                    plan_graph.searcher_resp_queue.put(plan_graph.end_signal)

        def is_finished(node):
            return ('detail' in node
                    and node['detail'].state == AgentStatusCode.END)

        command = extract_code(command)
        producer_thread = threading.Thread(target=run_command,
                                           args=(command, ))
        producer_thread.start()

        responses = defaultdict(list)
        ordered_nodes = []
        active_node = None

        while True:
            plan_graph = self.local_dict.get('graph')
            if plan_graph is None:
                # The producer has not bound `graph` yet (or failed before
                # doing so): poll until it does or the thread exits.
                if producer_thread.is_alive():
                    producer_thread.join(timeout=0.05)
                    continue
                plan_graph = self.local_dict.get('graph')
                if plan_graph is None:
                    break
            try:
                item = plan_graph.searcher_resp_queue.get(timeout=60)
            except queue.Empty:
                if not producer_thread.is_alive():
                    break
                continue
            if item == WebSearchGraph.end_signal:
                # Batch complete: flush whatever is still buffered, in order.
                for node_name in ordered_nodes:
                    for resp in responses[node_name]:
                        yield deepcopy(resp)
                break
            node_name, node, adj = item
            if node_name in ['root', 'response']:
                yield deepcopy((node_name, node, adj))
                continue
            if node_name not in ordered_nodes:
                ordered_nodes.append(node_name)
            responses[node_name].append((node_name, node, adj))
            if not active_node and ordered_nodes:
                active_node = ordered_nodes[0]
            while active_node and responses[active_node]:
                latest = responses[active_node][-1]
                if return_early and is_finished(latest[1]):
                    item = latest
                else:
                    item = responses[active_node].pop(0)
                if is_finished(item[1]):
                    ordered_nodes.pop(0)
                    responses[active_node].clear()
                    active_node = None
                yield deepcopy(item)
        producer_thread.join()