File size: 2,970 Bytes
f3f614f
dc9e27a
f3f614f
 
dc9e27a
 
 
 
f3f614f
dc9e27a
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f614f
 
 
 
dc9e27a
 
 
 
 
 
f3f614f
dc9e27a
f3f614f
 
dc9e27a
 
 
 
 
 
 
f3f614f
dc9e27a
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f614f
dc9e27a
 
 
f3f614f
 
dc9e27a
 
 
 
 
 
f3f614f
dc9e27a
 
 
 
 
 
 
 
 
f3f614f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from copy import deepcopy
from datetime import datetime

from lagent.actions import AsyncWebBrowser, WebBrowser
from lagent.agents.stream import get_plugin_prompt
from lagent.prompts import InterpreterParser, PluginParser
from lagent.utils import create_object

from . import models as llm_factory
from .mindsearch_agent import AsyncMindSearchAgent, MindSearchAgent
from .mindsearch_prompt import (
    FINAL_RESPONSE_CN,
    FINAL_RESPONSE_EN,
    GRAPH_PROMPT_CN,
    GRAPH_PROMPT_EN,
    searcher_context_template_cn,
    searcher_context_template_en,
    searcher_input_template_cn,
    searcher_input_template_en,
    searcher_system_prompt_cn,
    searcher_system_prompt_en,
)

# Process-wide cache of instantiated LLM clients, keyed first by
# ``model_format`` and then by mode ("sync" / "async"), so repeated
# ``init_agent`` calls reuse one client per configuration.
LLM = {}


def _get_llm(model_format, use_async):
    """Return the cached LLM client for ``model_format``, creating it if needed.

    Args:
        model_format: Name of a config attribute on the ``models`` module.
        use_async: When True, the config's ``type`` is rewritten to the
            ``Async``-prefixed class under ``lagent.llms``.

    Returns:
        The client object produced by ``create_object`` for this config/mode.

    Raises:
        NotImplementedError: If ``models`` has no config named
            ``model_format`` (or it is ``None``).
    """
    mode = "async" if use_async else "sync"
    llm = LLM.get(model_format, {}).get(mode)
    if llm is not None:
        return llm

    # Without a default, getattr raises AttributeError for unknown formats,
    # which would make the NotImplementedError below unreachable.
    llm_cfg = getattr(llm_factory, model_format, None)
    if llm_cfg is None:
        raise NotImplementedError(
            f"Unsupported model_format: {model_format!r}")
    # Copy so the shared module-level config is never mutated.
    llm_cfg = deepcopy(llm_cfg)
    if use_async:
        llm_type = llm_cfg["type"]
        cls_name = (llm_type.split(".")[-1]
                    if isinstance(llm_type, str) else llm_type.__name__)
        # Avoid producing "AsyncAsync..." if the config already names an
        # async class.
        if not cls_name.startswith("Async"):
            cls_name = f"Async{cls_name}"
        llm_cfg["type"] = f"lagent.llms.{cls_name}"
    llm = create_object(llm_cfg)
    return LLM.setdefault(model_format, {}).setdefault(mode, llm)


def _build_plugins(search_engine, use_async):
    """Return the one-element plugin config list for the web browser tool.

    ``TencentSearch`` is given a secret id/key pair; any other engine is
    given a single API key. Credentials come from environment variables and
    are ``None`` when the variable is unset.
    """
    browser = dict(
        type=AsyncWebBrowser if use_async else WebBrowser,
        searcher_type=search_engine,
        topk=6,
    )
    if search_engine == "TencentSearch":
        browser.update(
            secret_id=os.getenv("TENCENT_SEARCH_SECRET_ID"),
            secret_key=os.getenv("TENCENT_SEARCH_SECRET_KEY"),
        )
    else:
        browser["api_key"] = os.getenv("WEB_SEARCH_API_KEY")
    return [browser]


def init_agent(lang="cn",
               model_format="internlm_server",
               search_engine="BingSearch",
               use_async=False):
    """Build a MindSearch agent wired to an LLM and a web-search plugin.

    Args:
        lang: ``"cn"`` selects the Chinese prompt set; any other value
            selects the English prompts.
        model_format: Name of the LLM config on the ``models`` module.
        search_engine: Passed as ``searcher_type`` to the web browser tool.
        use_async: Build ``AsyncMindSearchAgent`` with async LLM/browser
            classes instead of the synchronous variants.

    Returns:
        An ``AsyncMindSearchAgent`` if ``use_async`` else a
        ``MindSearchAgent``.

    Raises:
        NotImplementedError: If ``model_format`` has no config.
    """
    llm = _get_llm(model_format, use_async)
    plugins = _build_plugins(search_engine, use_async)
    is_cn = lang == "cn"
    date = datetime.now().strftime("The current date is %Y-%m-%d.")
    agent_cls = AsyncMindSearchAgent if use_async else MindSearchAgent
    return agent_cls(
        llm=llm,
        template=date,
        output_format=InterpreterParser(
            template=GRAPH_PROMPT_CN if is_cn else GRAPH_PROMPT_EN),
        searcher_cfg=dict(
            llm=llm,
            plugins=plugins,
            template=date,
            output_format=PluginParser(
                template=(searcher_system_prompt_cn
                          if is_cn else searcher_system_prompt_en),
                tool_info=get_plugin_prompt(plugins),
            ),
            user_input_template=(searcher_input_template_cn
                                 if is_cn else searcher_input_template_en),
            user_context_template=(searcher_context_template_cn
                                   if is_cn else searcher_context_template_en),
        ),
        summary_prompt=FINAL_RESPONSE_CN if is_cn else FINAL_RESPONSE_EN,
        max_turn=10,
    )