|
|
|
import sys, os, re,inspect,json,traceback,logging,argparse, copy |
|
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../") |
|
from tornado.web import RequestHandler,Application |
|
from tornado.ioloop import IOLoop |
|
from tornado.httpserver import HTTPServer |
|
from tornado.options import define,options |
|
from util import es_conn, setup_logging |
|
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity |
|
from nlp import huqie |
|
from nlp import query as Query |
|
from nlp import search |
|
from llm import HuEmbedding, GptTurbo |
|
import numpy as np |
|
from io import BytesIO |
|
from util import config |
|
from timeit import default_timer as timer |
|
from collections import OrderedDict |
|
|
|
# Search backend used by Handler.post. It stays None at import time and is
# replaced with a search.Dealer instance in the __main__ section of this file.
SE = None

# Name of a chunk text field. NOTE(review): not referenced anywhere else in
# this file (get_instruction hard-codes the same string as its default).
CFIELD = "content_ltks"

# Embedding model; passed to search.Dealer when the service starts.
EMBEDDING = HuEmbedding()

# Chat model that prompt_and_answer calls via LLM.chat(...).
LLM = GptTurbo()
|
|
|
def get_QA_pairs(hists):
    """Flatten a chat history into a list of role/content messages.

    Args:
        hists: Iterable of dict-like turns. Each turn may carry a "user"
            and/or an "assistant" entry; empty or missing entries are skipped.

    Returns:
        A list of ``{"content": <text>, "role": <"user"|"assistant">}`` dicts,
        in history order, with "user" emitted before "assistant" inside a
        single turn.
    """
    messages = [
        {"content": turn[role], "role": role}
        for turn in hists
        for role in ("user", "assistant")
        if turn.get(role)
    ]

    # Sanity check carried over from the original implementation: every
    # message except the last one must have exactly two keys.
    for message in messages[:-1]:
        assert len(message) == 2, message

    return messages
|
|
|
|
|
|
|
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
    """Build the knowledge-base section of the prompt from selected chunks.

    Args:
        sres: Search result object. Only ``sres.ids`` (indexable by position)
            and ``sres.field`` (chunk id -> mapping of field name -> text)
            are read.
        top_i: Positions into ``sres.ids`` of the chunks to include, in the
            order they should be considered.
        max_len: Total character budget. It is divided evenly (integer
            division) between the selected chunks and each chunk is truncated
            to its share.
        fld: Name of the field that holds the chunk text.

    Returns:
        A string in which every included chunk is preceded by a newline and
        the label "知识库:". Returns an empty string when ``top_i`` is empty
        (previously this case raised ZeroDivisionError).
    """
    # Nothing selected: there is no knowledge text to emit, and dividing the
    # budget by zero chunks would fail. len() is used instead of truthiness so
    # that array-like inputs are handled too.
    if len(top_i) == 0:
        return ""

    per_chunk = max_len // len(top_i)

    # Line breaks inside a chunk are collapsed to single spaces so that each
    # chunk occupies exactly one line of the prompt.
    instructions = [
        re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i
    ]

    if len(instructions) > 2:
        # Move the second chunk to the end of the list. Presumably this keeps
        # the two best-ranked chunks at the start and the end of the prompt;
        # the motivation is not stated in this file.
        instructions.append(instructions.pop(1))

    def token_num(txt):
        # Rough size estimate: text is split on punctuation; a piece made only
        # of ASCII letters/hyphens counts as 1, any other piece counts one per
        # character.
        count = 0
        for piece in re.split(r"[,。/?‘’”“:;:;!!]", txt):
            if re.match(r"[a-zA-Z-]+$", piece):
                count += 1
            else:
                count += len(piece)
        return count

    _inst = ""
    for ins in instructions:
        if token_num(_inst) > 4096:
            # Budget exceeded: close with the final chunk and stop, skipping
            # whatever lies in between.
            _inst += "\n知识库:" + instructions[-1][:per_chunk]
            break
        _inst += "\n知识库:" + ins[:per_chunk]
    return _inst
|
|
|
|
|
def prompt_and_answer(history, inst):
    """Ask the chat model to answer using the retrieved knowledge text.

    Args:
        history: Chat history; converted to role/content messages with
            get_QA_pairs.
        inst: Knowledge-base text that is embedded in the system prompt.

    Returns:
        Whatever ``LLM.chat`` returns for the assembled system prompt and
        messages.
    """
    messages = get_QA_pairs(history)

    # Share of distinct fragments in the knowledge text. The 0.1 in the
    # denominator keeps the division defined when there are no fragments.
    fragments = [frag for frag in re.split(r"[::;;。\n\r]+", inst) if frag]
    distinct_ratio = len(set(fragments)) / (0.1 + len(fragments))
    print("Duplication portion:", distinct_ratio)

    # When many fragments repeat (ratio below 0.6 but non-zero) the prompt
    # additionally asks the model to summarise the data as a table.
    table_hint = ",最好总结成表格" if 0 < distinct_ratio < 0.6 else ""
    system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
""" % (table_hint, inst)

    print("【PROMPT】:", system)
    started = timer()
    reply = LLM.chat(system, messages, {"temperature": 0.2, "max_tokens": 512})
    print("GENERATE: ", timer() - started)
    print("===>>", reply)
    return reply
|
|
|
|
|
class Handler(RequestHandler):
    """Tornado handler that answers a chat request from retrieved chunks."""

    def post(self):
        """Handle one chat request and write a JSON reply.

        Fields read from the JSON request body:
            history: list of turns; the last turn's "user" text is the
                question. Defaults to ``[{"user": "Hi!"}]`` when missing or
                empty.
            kb_ids: passed to the search call (default ``[]``).
            topn: search size (default 15) and cap on chunks used for the
                prompt (default 12).
            similarity: minimum rerank score for a chunk to be used
                (default 0.5).
            uid: required; used to derive the index name and echoed back.
            dialog_id: required; echoed back.

        On success writes ``{"code": 0, "msg": "success", "data": {...}}``
        where data holds uid, dialog_id, the answer under "assistant" and the
        referenced chunks grouped per document under "refer". On any
        exception writes ``{"code": 500, "msg": <error text>, "data": {}}``.
        """
        try:
            # Parsed inside the try block so that a malformed body produces
            # the same JSON error envelope as every other failure.
            param = json.loads(self.request.body.decode('utf-8'))

            # Resolve the history once; the same value feeds both the search
            # question and the chat prompt.
            history = param.get("history") or [{"user": "Hi!"}]
            question = history[-1]["user"]
            res = SE.search({
                "question": question,
                "kb_ids": param.get("kb_ids", []),
                "size": param.get("topn", 15)},
                search.index_name(param["uid"])
            )

            sim = SE.rerank(res, question)
            # Positions ordered from highest to lowest score.
            rk_idx = np.argsort(sim*-1)
            min_sim = param.get("similarity", 0.5)
            topidx = [i for i in rk_idx if sim[i] >= min_sim][:param.get("topn", 12)]
            # No chunk passed the threshold: use an empty knowledge section
            # instead of asking get_instruction to split a budget by zero.
            inst = get_instruction(res, topidx) if topidx else ""

            # prompt_and_answer in this file returns a single value. A
            # two-item (answer, indices) result is still accepted so either
            # return shape works.
            reply = prompt_and_answer(history, inst)
            if isinstance(reply, (tuple, list)) and len(reply) == 2:
                ans, topidx = reply
            else:
                ans = reply
            ans = SE.insert_citations(ans, topidx, res)

            # Group every retrieved chunk by its document, keeping rank order.
            refer = OrderedDict()
            docnms = {}
            for i in rk_idx:
                chunk_id = res.ids[i]
                fields = res.field[chunk_id]
                did = fields["doc_id"]
                if did not in docnms:
                    docnms[did] = fields["docnm_kwd"]
                if did not in refer:
                    refer[did] = []
                refer[did].append({
                    "chunk_id": chunk_id,
                    "content": fields["content_ltks"],
                    "image": ""
                })

            print("::::::::::::::", ans)
            self.write(json.dumps({
                "code": 0,
                "msg": "success",
                "data": {
                    "uid": param["uid"],
                    "dialog_id": param["dialog_id"],
                    "assistant": ans,
                    "refer": [{
                        "did": did,
                        "doc_name": docnms[did],
                        "chunks": chunks
                    } for did, chunks in refer.items()]
                }
            }))
            logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))

        except Exception as e:
            # Top-level request boundary: report the failure to the client in
            # the JSON envelope rather than letting the request crash.
            logging.error("Request 500: "+str(e))
            self.write(json.dumps({
                "code": 500,
                "msg": str(e),
                "data": {}
            }))
            print(traceback.format_exc())
|
|
|
|
|
if __name__ == '__main__':
    # Command line: only the listening port is configurable.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--port",
        default=4455,
        type=int,
        help="Port used for service",
    )
    ARGS = parser.parse_args()

    # Replace the module-level placeholder with the real search backend that
    # Handler.post uses.
    es = es_conn.HuEs("infiniflow")
    SE = search.Dealer(es, EMBEDDING)

    # Single route: chat completions are served by Handler.
    routes = [(r'/v1/chat/completions', Handler)]
    app = Application(routes, debug=False)

    http_server = HTTPServer(app)
    http_server.bind(ARGS.port)
    # Per Tornado's server API, start(n) with n > 1 forks that many processes.
    http_server.start(3)

    IOLoop.current().start()
|
|
|
|