# -*- coding: utf-8 -*-
import sys, os, re, inspect, json, traceback, logging, argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler, Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define, options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
SE = None
CFIELD = "content_ltks"
EMBEDDING = HuEmbedding()
LLM = GptTurbo()
def get_QA_pairs(hists):
    # Flatten the chat history into an ordered list of {role, content} messages.
    pa = []
    for h in hists:
        for k in ["user", "assistant"]:
            if h.get(k):
                pa.append({
                    "content": h[k],
                    "role": k,
                })

    # Sanity check: every message dict carries both a role and a content field.
    for p in pa[:-1]:
        assert len(p) == 2, p
    return pa
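
# For illustration (hypothetical input), the flattening works roughly like:
#   get_QA_pairs([{"user": "hi", "assistant": "hello"}, {"user": "next?"}])
#   -> [{"content": "hi", "role": "user"},
#       {"content": "hello", "role": "assistant"},
#       {"content": "next?", "role": "user"}]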
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
    max_len //= len(top_i)
    # Add the retrieved references to the prompt.
    instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
    if len(instructions) > 2:
        # LLMs are said to be most sensitive to the first and the last reference,
        # so move the second reference to the end.
        instructions.append(copy.deepcopy(instructions[1]))
        instructions.pop(1)
    def token_num(txt):
        # Rough token count: a chunk of pure ASCII letters counts as one token;
        # any other chunk (e.g. CJK text) counts one token per character.
        c = 0
        for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
            if re.match(r"[a-zA-Z-]+$", tk):
                c += 1
                continue
            c += len(tk)
        return c
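
    # For illustration, the heuristic counts roughly like this:
    #   token_num("hello,世界") -> 1 (the ASCII chunk "hello")
    #                              + 2 (two CJK characters in "世界") = 3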
    _inst = ""
    for ins in instructions:
        if token_num(_inst) > 4096:
            # Token budget exceeded: append the (moved) second reference and stop.
            # "知识库" means "knowledge base" -- the label the prompt template expects.
            _inst += "\n知识库:" + instructions[-1][:max_len]
            break
        _inst += "\n知识库:" + ins[:max_len]
    return _inst
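
# The returned context block looks roughly like this (contents hypothetical;
# "知识库" means "knowledge base"):
#   \n知识库:first reference ...\n知识库:second reference ...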
def prompt_and_answer(history, inst):
    hist = get_QA_pairs(history)

    # Estimate how repetitive the references are: the ratio of unique sentences
    # to total sentences (the 0.1 avoids division by zero).
    chks = []
    for s in re.split(r"[::;;。\n\r]+", inst):
        if s:
            chks.append(s)
    uniq_ratio = len(set(chks)) / (0.1 + len(chks))
    print("Duplication portion:", uniq_ratio)

    # System prompt (in Chinese). It instructs the model to answer by summarizing
    # the knowledge base in detail, to include the sentence "知识库中未找到您要的答案!
    # 这是我所知道的,仅作参考。" ("The answer was not found in the knowledge base!
    # This is what I know, for reference only.") when nothing is relevant, and to
    # take the chat history into account. If the references look repetitive
    # (uniq_ratio < 0.6), it additionally suggests summarizing into a table.
    system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
""" % ((",最好总结成表格" if 0 < uniq_ratio < 0.6 else ""), inst)
    print("【PROMPT】:", system)

    start = timer()
    response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
    print("GENERATE: ", timer() - start)
    print("===>>", response)
    return response
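
# For illustration of the duplication heuristic (hypothetical inst): splitting
# "a;a;a;b" yields ["a", "a", "a", "b"], so uniq_ratio = 2 / (0.1 + 4) ≈ 0.49,
# which is below 0.6 and triggers the "summarize into a table" hint.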
class Handler(RequestHandler):
    def post(self):
        global SE
        param = json.loads(self.request.body.decode('utf-8'))
        try:
            question = param.get("history", [{"user": "Hi!"}])[-1]["user"]
            res = SE.search({
                "question": question,
                "kb_ids": param.get("kb_ids", []),
                "size": param.get("topn", 15)
            }, search.index_name(param["uid"]))

            # Rerank by similarity and keep the chunks above the similarity
            # threshold, up to topn.
            sim = SE.rerank(res, question)
            rk_idx = np.argsort(sim * -1)
            topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn", 12)]
            inst = get_instruction(res, topidx)

            ans = prompt_and_answer(param["history"], inst)
            ans = SE.insert_citations(ans, topidx, res)

            # Group the retrieved chunks by document for the "refer" field.
            refer = OrderedDict()
            docnms = {}
            for i in rk_idx:
                did = res.field[res.ids[i]]["doc_id"]
                if did not in docnms:
                    docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
                if did not in refer:
                    refer[did] = []
                refer[did].append({
                    "chunk_id": res.ids[i],
                    "content": res.field[res.ids[i]]["content_ltks"],
                    "image": ""
                })

            print("::::::::::::::", ans)
            self.write(json.dumps({
                "code": 0,
                "msg": "success",
                "data": {
                    "uid": param["uid"],
                    "dialog_id": param["dialog_id"],
                    "assistant": ans,
                    "refer": [{
                        "did": did,
                        "doc_name": docnms[did],
                        "chunks": chunks
                    } for did, chunks in refer.items()]
                }
            }))
            logging.info("SUCCESS[%d]" % (res.total) + json.dumps(param, ensure_ascii=False))
        except Exception as e:
            logging.error("Request 500: " + str(e))
            self.write(json.dumps({
                "code": 500,
                "msg": str(e),
                "data": {}
            }))
            print(traceback.format_exc())
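
# A request body this handler accepts looks roughly like this (field names are
# taken from the code above; the values are hypothetical):
#   {
#       "uid": "u123",
#       "dialog_id": "d456",
#       "kb_ids": ["kb1"],
#       "topn": 12,
#       "similarity": 0.5,
#       "history": [{"user": "hi", "assistant": "hello"}, {"user": "next?"}]
#   }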
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=4455, type=int, help="Port used for service")
    ARGS = parser.parse_args()

    SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
    app = Application([(r'/v1/chat/completions', Handler)], debug=False)
    http_server = HTTPServer(app)
    http_server.bind(ARGS.port)
    # Fork 3 worker processes that share the bound port.
    http_server.start(3)
    IOLoop.current().start()
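
# Once the service is running, it can be exercised roughly like this (payload
# values are hypothetical; the route and default port come from the code above):
#   curl -X POST http://localhost:4455/v1/chat/completions \
#        -H 'Content-Type: application/json' \
#        -d '{"uid": "u123", "dialog_id": "d456", "kb_ids": ["kb1"],
#             "history": [{"user": "hi"}]}'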