File size: 5,634 Bytes
3245107
 
 
 
 
 
 
 
 
 
 
3fc700a
3245107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fc700a
3245107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fc700a
 
 
 
 
3245107
 
 
 
 
 
 
 
 
 
 
 
3fc700a
 
3245107
 
 
3fc700a
3245107
 
 
 
 
 
 
 
 
 
3fc700a
3245107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fc700a
3245107
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#-*- coding:utf-8 -*-
import sys, os, re,inspect,json,traceback,logging,argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler,Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define,options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict

# Search dealer shared by all requests; stays None until the __main__ block
# assigns a search.Dealer instance to it.
SE = None
# Name of the chunk-content field. Not referenced elsewhere in the visible code.
CFIELD="content_ltks"
# HuEmbedding instance created at import time; handed to search.Dealer at startup.
EMBEDDING = HuEmbedding()
# GptTurbo instance created at import time; prompt_and_answer() calls LLM.chat().
LLM = GptTurbo()

def get_QA_pairs(hists):
    """Flatten a chat history into a flat list of role/content messages.

    Args:
        hists: Iterable of dicts, each of which may carry a "user" and/or an
            "assistant" entry. ``None`` or an empty iterable yields ``[]``.

    Returns:
        A list of ``{"content": ..., "role": ...}`` dicts in history order.
        Within one history entry the "user" message precedes the
        "assistant" message. Missing or falsy entries are skipped.
    """
    messages = []
    for turn in hists or []:
        for role in ("user", "assistant"):
            content = turn.get(role)
            if content:
                messages.append({
                    "content": content,
                    "role": role,
                })
    return messages



def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
    """Build the knowledge-base section of the prompt from the selected chunks.

    Args:
        sres: Search result exposing ``ids`` (indexable by position) and
            ``field`` (mapping of chunk id to a dict of stored fields).
        top_i: Positions into ``sres.ids`` of the chunks to include.
        max_len: Overall character budget; it is divided evenly between the
            selected chunks and each chunk is truncated to its share.
        fld: Name of the field holding the chunk text.

    Returns:
        A string in which every included chunk is preceded by a newline and
        the "知识库:" marker. Returns "" when ``top_i`` is empty.
    """
    # Without this guard the per-chunk budget below divides by zero.
    if len(top_i) == 0:
        return ""
    per_chunk = max_len // len(top_i)

    # Collapse line breaks so each chunk occupies a single line of the prompt.
    instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
    if len(instructions) > 2:
        # LLMs are said to be most sensitive to the first and the last
        # reference, so move the second chunk to the end of the list.
        instructions.append(instructions.pop(1))

    def token_num(txt):
        """Rough size estimate: split on punctuation, count a pure
        ASCII-letter/hyphen piece as 1 and any other piece by its length."""
        c = 0
        for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
            if re.match(r"[a-zA-Z-]+$", tk):
                c += 1
            else:
                c += len(tk)
        return c

    _inst = ""
    for ins in instructions:
        if token_num(_inst) > 4096:
            # Budget exceeded: close with the last reference and stop.
            _inst += "\n知识库:" + instructions[-1][:per_chunk]
            break
        _inst += "\n知识库:" + ins[:per_chunk]
    return _inst


def prompt_and_answer(history, inst):
    """Assemble the system prompt around the knowledge text and query the LLM.

    Args:
        history: Chat history, converted to role/content messages with
            get_QA_pairs().
        inst: Knowledge-base text that is embedded into the system prompt.

    Returns:
        Whatever ``LLM.chat`` returns for the prompt and the messages.
    """
    messages = get_QA_pairs(history)

    # Share of distinct segments among all non-empty segments of the
    # knowledge text; the 0.1 keeps the division defined for empty input.
    segments = [seg for seg in re.split(r"[::;;。\n\r]+", inst) if seg]
    chks = len(set(segments)) / (0.1 + len(segments))
    print("Duplication portion:", chks)

    # Ask for a table only when the ratio lies strictly between 0 and 0.6.
    if 0 < chks < 0.6:
        table_hint = ",最好总结成表格"
    else:
        table_hint = ""

    system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
"""%(table_hint, inst)

    print("【PROMPT】:", system)
    began = timer()
    gen_conf = {"temperature": 0.2, "max_tokens": 512}
    response = LLM.chat(system, messages, gen_conf)
    print("GENERATE: ", timer()-began)
    print("===>>", response)
    return response


class Handler(RequestHandler):
    """Request handler that answers a chat question from retrieved chunks."""

    def post(self):
        """Handle one chat request.

        The JSON request body is read for these keys:
            history: list of {"user": ..., "assistant": ...} dicts; the
                "user" entry of the last item is the question. A missing or
                empty history falls back to a single "Hi!" user turn.
            uid: passed to search.index_name() and echoed in the reply.
            dialog_id: echoed in the reply.
            kb_ids: optional, forwarded to the search (default []).
            topn: optional; search size (default 15) and cap on the number
                of chunks placed in the prompt (default 12).
            similarity: optional minimum rerank score for a chunk to be
                placed in the prompt (default 0.5).

        Writes a JSON object with "code", "msg" and "data". On any failure
        "code" is 500, "msg" is the exception text and "data" is empty.
        """
        try:
            # Parsed inside the try so a malformed body also yields the JSON
            # error reply instead of an unhandled exception.
            param = json.loads(self.request.body.decode('utf-8'))
            history = param.get("history") or [{"user": "Hi!"}]
            question = history[-1]["user"]
            res = SE.search({
                    "question": question,
                    "kb_ids": param.get("kb_ids", []),
                    "size": param.get("topn", 15)},
                search.index_name(param["uid"])
            )

            sim = SE.rerank(res, question)
            # Positions ordered by descending score.
            rk_idx = np.argsort(sim*-1)
            threshold = param.get("similarity", 0.5)
            topidx = [i for i in rk_idx if sim[i] >= threshold][:param.get("topn",12)]
            # get_instruction() divides its length budget by len(top_i), so
            # it is skipped when no chunk passed the threshold.
            inst = get_instruction(res, topidx) if topidx else ""

            # prompt_and_answer() returns a single value; topidx keeps the
            # positions selected above for the citation step.
            ans = prompt_and_answer(history, inst)
            ans = SE.insert_citations(ans, topidx, res)

            # Group every ranked chunk by its document, keeping rank order.
            refer = OrderedDict()
            docnms = {}
            for i in rk_idx:
                chunk_id = res.ids[i]
                fields = res.field[chunk_id]
                did = fields["doc_id"]
                if did not in docnms:
                    docnms[did] = fields["docnm_kwd"]
                refer.setdefault(did, []).append({
                    "chunk_id": chunk_id,
                    "content": fields["content_ltks"],
                    "image": ""
                })

            print("::::::::::::::", ans)
            self.write(json.dumps({
                "code":0,
                "msg":"success",
                "data":{
                    "uid": param["uid"],
                    "dialog_id": param["dialog_id"],
                    "assistant": ans,
                    "refer": [{
                        "did": did,
                        "doc_name": docnms[did],
                        "chunks": chunks
                    } for did, chunks in refer.items()]
                }
            }))
            logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))

        except Exception as e:
            # Top-level boundary: report the failure in the JSON envelope.
            logging.error("Request 500: "+str(e))
            self.write(json.dumps({
                "code":500,
                "msg":str(e),
                "data":{}
            }))
            print(traceback.format_exc())


# Script entry point: parse the port, build the shared search dealer and
# serve the chat endpoint.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=4455, type=int, help="Port used for service")
    ARGS = parser.parse_args()
    
    # Bind the module-level SE that Handler.post() reads; it must be set
    # before the server starts accepting requests.
    SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)

    # Single route: POST /v1/chat/completions is served by Handler.
    app = Application([(r'/v1/chat/completions', Handler)],debug=False)
    http_server = HTTPServer(app)
    http_server.bind(ARGS.port)
    # NOTE(review): per Tornado's HTTPServer.start(num_processes) API this
    # requests 3 server processes — confirm that SE/EMBEDDING/LLM, which are
    # created before this call, are safe to share across them.
    http_server.start(3)

    IOLoop.current().start()