Spaces:
Runtime error
Runtime error
File size: 6,531 Bytes
031beb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import re
import os
import unicodedata
from typing import List
import uuid
import hashlib
import pandas as pd
from common.call_llm import chat_stream_generator
# Endpoint of the LLM service used for document QA, read from the environment.
# NOTE(review): os.environ.get returns None when the variable is unset — confirm
# chat_stream_generator copes with a None endpoint or fail fast at startup.
DOC_QA_ENDPOINT = os.environ.get("DOC_QA_ENDPOINT")
# Top-level prompt sent to the model; filled via str.format in get_prompt with:
#   {context}      - rendered reference document fragments
#   {chat_history} - serialized previous chat turns
#   {question}     - the user's question
#   {requirement}  - numbered answering rules
# The Chinese preamble introduces the assistant ("聚言", developed by 猎户星空) and
# says it answers using the reference material and chat history below; the section
# headers read "reference material", "chat history", "user question" and
# "answer requirements".
prompt_template = """你是由猎户星空开发的AI助手,你的名字叫聚言。你可以根据下面给出的参考资料和聊天历史来回答用户问题。
### 参考资料 ###
{context}
### 聊天历史 ###
{chat_history}
### 用户问题 ###
{question}
### 回答要求 ###
{requirement}
"""
def document_prompt_template():
    """Return the template used to render a single reference fragment.

    The returned string still contains the ``{doc_id}`` and ``{page_content}``
    placeholders; callers fill them in with ``str.format``.
    """
    template = '["Source_id": {doc_id},"Content": "{page_content}"]'
    return template
def language_detect(text: str) -> str:
    """Guess the dominant language of *text* from the Unicode names of its characters.

    A fixed set of bullet, dash, dot and quote characters, spaces, tabs and newlines
    is removed first. Every remaining character that has a Unicode name is counted
    and classified by its name: CJK -> "zh", Hiragana/Katakana -> "ja",
    Thai -> "th", Hangul -> "ko", Latin -> "en", anything else -> "other".

    Args:
        text: Text to inspect.

    Returns:
        The language code ("zh", "ja", "ko", "en" or "th") with the highest share of
        counted characters, or "" when no language scores above zero (including
        empty or punctuation-only input).
    """
    text = re.sub(r"([ ■◼•*…— �●⚫]+|[·\.~•、—'}\n\t]{1,})", '', text.strip())
    stats = {"zh": 0, "ja": 0, "ko": 0, "en": 0, "th": 0, "other": 0}
    char_count = 0
    for char in text:
        # Characters without a Unicode name (e.g. most control characters) are ignored.
        code_name = unicodedata.name(char, "")
        if not code_name:
            continue
        char_count += 1
        if 'CJK' in code_name:
            stats["zh"] += 1
        elif 'HIRAGANA' in code_name or 'KATAKANA' in code_name:
            stats["ja"] += 1
        elif "THAI" in code_name:
            stats["th"] += 1
        elif 'HANGUL' in code_name:
            stats["ko"] += 1
        # Match the full script name: a bare 'LA' substring also matched names such
        # as "DOLLAR SIGN", "ARABIC LETTER LAM" and Lao letters.
        elif 'LATIN' in code_name:
            stats["en"] += 1
        else:
            stats["other"] += 1
    # Nothing could be classified: there is no ratio to compute (and dividing
    # by zero would crash).
    if char_count == 0:
        return ""
    lang = ""
    ratio = 0.0
    for lan, count in stats.items():
        if lan == "other":
            continue
        # Trick: counting Latin letters overstates English; dividing by 4 roughly
        # approximates the number of words.
        if lan == "en":
            count /= 4.0
        lan_r = float(count) / char_count
        # Strict comparison: on ties the earlier language in `stats` wins.
        if ratio < lan_r:
            lang = lan
            ratio = lan_r
    return lang
def language_prompt(lan: str) -> str:
    """Return the Chinese name of the language the model should answer in.

    Only English ("en"), Korean ("ko") and Thai ("th") map to their own language
    name. Every other code falls back to Chinese ("中文"), including "zh", "zh_gd",
    "ja", "other", "" and unknown codes. Matching is case-insensitive.
    """
    special_languages = {"en": "英文", "ko": "韩文", "th": "泰文"}
    return special_languages.get(lan.lower(), "中文")
def _get_chat_history(chat_history: List[List]) -> str:
if not chat_history:
return ""
chat_history_text = ""
for human_msg, ai_msg in chat_history:
human = "{'Human': '" + human_msg + "'}"
ai = "{'AI': '" + ai_msg + "'}"
chat_history_text += "[" + ", ".join([human, ai]) + "]\n"
return chat_history_text
def get_prompt(context: str, chat_history: str, question: str, trapped_switch: int, fallback: str,
               citations_switch: int) -> str:
    """Assemble the full document-QA prompt from ``prompt_template``.

    The answering requirements form a numbered list:
      1-2. two fixed rules (answer only from the references; ask clarifying
           questions when helpful);
      next, only if *fallback* is non-empty and *trapped_switch* == 1: reply with the
           fallback text verbatim when the references are insufficient;
      next, only if *citations_switch* is truthy: cite the used Source_ids in [[...]];
      last: answer in first person, rigorously, in the language detected from
           *question*.

    Args:
        context: Rendered reference document fragments.
        chat_history: Serialized previous turns.
        question: The user's question.
        trapped_switch: 1 enables the custom fallback rule.
        fallback: Text the model should reply with when the references are insufficient.
        citations_switch: Truthy to request Source_id citations.

    Returns:
        The filled-in prompt string.
    """
    rules = ["1. 你只能根据上面参考资料中给出的事实信息来回答用户问题,不要胡编乱造。",
             "2. 如果向用户提出澄清问题有助于回答问题,可以尝试提问。"]

    def add_rule(text: str) -> None:
        # Optional rules are numbered right after whatever is already listed.
        rules.append(str(len(rules) + 1) + ". " + text)

    if len(fallback) > 0 and trapped_switch == 1:
        add_rule("""如果参考资料中的信息不足以回答用户问题,请直接回答下面三个双引号中的内容:\"\"\"{fallback}\"\"\"。""".format(
            fallback=fallback))
    if citations_switch:
        add_rule("如果你给出的答案里引用了参考资料中的内容,请在答案的结尾处添加你引用的Source_id,引用的Source_id值来自于参考资料中,并用两个方括号括起来。示例:[[d97b811489b73f46c8d2cb1bc888dbbe]]、[[b6be48868de736b90363d001c092c019]]")
    answer_language = language_prompt(language_detect(question))
    add_rule("""请你以第一人称并且用严谨的风格来回答问题,一定要用{language}来回答,并且基于事实详细阐述。""".format(
        language=answer_language,
    ))
    return prompt_template.format(context=context, chat_history=chat_history, question=question,
                                  requirement="\n".join(rules))
def _build_doc_context(doc_df: "pd.DataFrame"):
    """Render the usable document fragments of *doc_df* as prompt context.

    Fragment names are not placed in the context; each fragment is identified there
    by a fresh random Source_id (MD5 hex digest of a UUID4). The returned map
    translates those ids back to fragment names.

    Args:
        doc_df: Frame with columns "文档片段内容" (fragment content) and
            "文档片段名称" (fragment name); may be None.

    Returns:
        ``(context, source_id_map)``: the rendered fragments joined by blank lines, and
        a dict mapping Source_id -> fragment name (as str).
    """
    rendered = []
    source_id_map = {}
    if doc_df is None:
        return "", source_id_map
    template = document_prompt_template()
    for _, row in doc_df.iterrows():
        content = row["文档片段内容"]
        name = row["文档片段名称"]
        # NaN is truthy, so `not value` alone lets missing cells through; a NaN name
        # used to reach str.replace during citation rewriting and raise TypeError.
        if pd.isna(content) or pd.isna(name) or not content or not name:
            continue
        source_id = hashlib.md5(str(uuid.uuid4()).encode("utf-8")).hexdigest()
        source_id_map[source_id] = str(name)
        rendered.append(template.format(doc_id=source_id, page_content=content))
    return "\n\n".join(rendered), source_id_map


def _restore_citations(text: str, source_id_map: dict) -> str:
    """Replace Source_ids inside ``[[...]]`` markers of *text* with fragment names.

    Ids missing from *source_id_map* (e.g. made up by the model) are left as-is.
    """
    pattern = r"\[\[(.*?)\]\]"
    print(f"Matched source ids {re.findall(pattern, text)}")
    return re.sub(
        pattern,
        lambda match: "[[" + source_id_map.get(match.group(1), match.group(1)) + "]]",
        text,
    )


def generate_doc_qa(input_text: str, history: List[List[str]], doc_df: "pd.DataFrame", trapped_switch: str, fallback: str,
                    citations_switch: str):
    """Generates chat responses according to the input text, history and page content.

    The model's answer is streamed into the AI slot of the last turn of *history*.

    Args:
        input_text: User question; "你好" is used when it is empty.
        history: Chat turns ``[human, ai]``. Only the last 5 turns go into the prompt,
            but the full history is yielded back. An empty history gets a new turn for
            *input_text*; a ``None`` AI slot in the last turn is treated as "".
        doc_df: Reference fragments (columns "文档片段内容" / "文档片段名称"); may be None.
        trapped_switch: "自定义话术" enables the custom fallback answer.
        fallback: Custom fallback answer text.
        citations_switch: "开启引用" asks the model to cite Source_ids.

    Yields:
        ``(None, history)`` after each chunk shown immediately, plus once at the end
        when text was buffered for citation rewriting.
    """
    print(f"input_text: {input_text}, history: {history}, page_content: {doc_df}, trapped_switch: {trapped_switch}, fallback: {fallback}, citations_switch: {citations_switch}")
    citations_on = 1 if citations_switch == "开启引用" else 0
    trapped_on = 1 if trapped_switch == "自定义话术" else 0
    fallback = fallback or ""
    input_text = input_text or "你好"
    history = list(history or [])
    if not history:
        # Nothing to stream into: start a turn for the current question.
        history.append([input_text, ""])
    if history[-1][1] is None:
        history[-1][1] = ""

    context, source_id_map = _build_doc_context(doc_df)
    # Only the last 5 turns go into the prompt; the full history is still yielded so
    # the displayed conversation is not truncated.
    prompt = get_prompt(context, _get_chat_history(history[-5:]), input_text, trapped_on, fallback,
                        citations_on)
    print(f"docQA prompt: {prompt}")
    messages = [{"role": "user", "content": prompt}]
    stream_response = chat_stream_generator(messages=messages, endpoint=DOC_QA_ENDPOINT)

    # From the first chunk containing "[" onward, everything is buffered so citation
    # markers can be rewritten as a whole at the end instead of streamed half-done.
    cache = ""
    for character in stream_response:
        if "[" in character or cache:
            cache += character
            continue
        history[-1][1] += character
        yield None, history
    if cache:
        history[-1][1] += _restore_citations(cache, source_id_map)
        yield None, history